From cdabcf61f33def061dcdb3d648f223ac18f800c0 Mon Sep 17 00:00:00 2001 From: Nathan Anderson Date: Thu, 27 Jul 2023 01:23:10 -0600 Subject: [PATCH] Added authentication and shared notes endpoints --- src/db/db.zig | 119 +++++++++++-- src/db/models.zig | 121 +++++++++++-- src/http_handler.zig | 28 +-- src/routes/auth.zig | 90 ++++++++++ src/routes/budget.zig | 236 +++++++++++++------------ src/routes/dashboard.zig | 76 +++++--- src/routes/shared_note.zig | 155 +++++++++++++++++ src/routes/transactions.zig | 102 +++++++---- src/routes/user.zig | 333 ++++++++++++++++++++++++++++-------- src/utils.zig | 23 +++ 10 files changed, 1014 insertions(+), 269 deletions(-) create mode 100644 src/routes/shared_note.zig diff --git a/src/db/db.zig b/src/db/db.zig index 2b68683..98e58c4 100644 --- a/src/db/db.zig +++ b/src/db/db.zig @@ -31,21 +31,85 @@ pub const Db = struct { self: *Db, comptime Type: type, allocator: Allocator, - comptime whereClause: []const u8, + comptime where_clause: []const u8, values: anytype, - comptime limit: ?u32, + comptime order_by_field: ?[]const u8, + comptime order: ?[]const u8, + comptime limit: ?bool, + limit_val: ?u32, ) !?[]Type { - _ = limit; - + comptime { + if (order == null and order_by_field != null or order != null and order_by_field == null) { + @compileError("Must provide both order and order_by or neither, with select " ++ where_clause); + } + if (order != null) { + if (!(std.mem.eql(u8, order.?, "DESC") or std.mem.eql(u8, order.?, "ASC"))) { + @compileError("Must use ASC or DESC for order_by, used: " ++ order.?); + } + } + if (!std.mem.containsAtLeast(u8, where_clause, 1, "?")) { + @compileError("where_clause missing '?', no possible values to insert " ++ where_clause); + } + // Check that where_clause only contains fields in struct + var query_objs_iter = std.mem.split(u8, where_clause, "?"); + inline for (@typeInfo(@TypeOf(values)).Struct.fields) |struct_field| { + const name = struct_field.name; + const query_obj = query_objs_iter.next(); + if (query_obj == null) { + @compileError("Query does not have enough clauses for passed in data: " ++ where_clause); + } + if (std.mem.containsAtLeast(u8, query_obj.?, 1, name)) { + continue; + } else { + @compileError("Missing field or messed up order in select:\n" ++ where_clause ++ "\n" ++ query_obj.?); + } + } + const last = query_objs_iter.next(); + if (last != null and !std.mem.eql(u8, last.?, "")) { + @compileError("Values is lacking, query contains more ? than data provided: " ++ where_clause ++ "\nLeft with: " ++ last.?); + } + } + if (limit == null and limit_val != null or limit != null and limit_val == null) { + std.log.err("Must provide both limit and limit_val or neither, with select: {s}", .{where_clause}); + return null; + } var res_array: std.ArrayList(Type) = std.ArrayList(Type).init(allocator); + if (order_by_field == null and limit == null) { + const query = "SELECT * FROM " ++ models.getTypeTableName(Type) ++ " " ++ where_clause ++ ";"; + var stmt = try self._sql_db.prepare(query); + defer stmt.deinit(); - const query = "SELECT * FROM " ++ models.getTypeTableName(Type) ++ " " ++ whereClause ++ ";"; - var stmt = try self._sql_db.prepare(query); - defer stmt.deinit(); + var iter = try stmt.iteratorAlloc(Type, allocator, values); + while (try iter.nextAlloc(allocator, .{})) |row| { + try res_array.append(row); + } + } else if (order_by_field == null and limit != null) { + const query = "SELECT * FROM " ++ models.getTypeTableName(Type) ++ " " ++ where_clause ++ " LIMIT ?;"; + var stmt = try self._sql_db.prepare(query); + defer stmt.deinit(); - var iter = try stmt.iteratorAlloc(Type, allocator, values); - while (try iter.nextAlloc(allocator, .{})) |row| { - try res_array.append(row); + var iter = try stmt.iteratorAlloc(Type, allocator, utils.structConcatFields(values, .{ .limit = limit_val.? })); + while (try iter.nextAlloc(allocator, .{})) |row| { + try res_array.append(row); + } + } else if (order_by_field != null and limit == null) { + const query = "SELECT * FROM " ++ models.getTypeTableName(Type) ++ " " ++ where_clause ++ " ORDER BY " ++ order_by_field.? ++ " " ++ order.? ++ ";"; + var stmt = try self._sql_db.prepare(query); + defer stmt.deinit(); + + var iter = try stmt.iteratorAlloc(Type, allocator, values); + while (try iter.nextAlloc(allocator, .{})) |row| { + try res_array.append(row); + } + } else { + const query = "SELECT * FROM " ++ models.getTypeTableName(Type) ++ " " ++ where_clause ++ " ORDER BY " ++ order_by_field.? ++ " " ++ order.? ++ " LIMIT ?;"; + var stmt = try self._sql_db.prepare(query); + defer stmt.deinit(); + + var iter = try stmt.iteratorAlloc(Type, allocator, utils.structConcatFields(values, .{ .limit = limit_val.? })); + while (try iter.nextAlloc(allocator, .{})) |row| { + try res_array.append(row); + } } return try res_array.toOwnedSlice(); @@ -57,6 +121,21 @@ pub const Db = struct { } pub fn selectOne(self: *Db, comptime Type: type, allocator: Allocator, comptime query: []const u8, values: anytype) !?Type { + comptime { + var query_objs_iter = std.mem.split(u8, query, "="); + inline for (@typeInfo(@TypeOf(values)).Struct.fields) |struct_field| { + const name = struct_field.name; + const query_obj = query_objs_iter.next(); + if (query_obj == null) { + @compileError("Query does not have enough clauses for passed in data:\n" ++ "Type: " ++ @typeName(Type) ++ "\n" ++ query ++ "\n"); + } + if (std.mem.containsAtLeast(u8, query_obj.?, 1, name)) { + continue; + } else { + @compileError("Missing field or messed up order in select:\n" ++ query ++ "\n" ++ query_obj.?); + } + } + } const row = try self._sql_db.oneAlloc(Type, allocator, query, .{}, values); // std.debug.print("{any}", .{row}); return row; @@ -72,9 +151,25 @@ pub const Db = struct { } pub fn insert(self: *Db, comptime Type: type, values: anytype) !void { - // TODO check there is an ID field + comptime { + const query = models.createInsertQuery(Type); + // const InsertType = utils.removeStructFields(Type, &[_]u8{0}); + var query_objs_iter = std.mem.split(u8, query, ","); + inline for (@typeInfo(@TypeOf(values)).Struct.fields) |struct_field| { + const name = struct_field.name; + const query_obj = query_objs_iter.next(); + if (query_obj == null) { + @compileError("Query does not have enough clauses for passed in data:\n" ++ "Type: " ++ @typeName(Type) ++ "\n" ++ query ++ "\n"); + } + if (std.mem.containsAtLeast(u8, query_obj.?, 1, name)) { + continue; + } else { + @compileError("Missing field or messed up order in insert:\n" ++ query ++ "\n" ++ query_obj.? ++ "\n"); + } + } + } self._sql_db.exec(models.createInsertQuery(Type), .{}, values) catch |err| { - std.debug.print("Encountered error while inserting data:\n{any}\n\tQuery:{s}\n{any}\n", .{ values, models.createInsertQuery(Type), err }); + std.debug.print("Encountered error while inserting data:\n\t{any}\nQuery:\t{s}\n{any}\n", .{ values, models.createInsertQuery(Type), err }); return err; }; } diff --git a/src/db/models.zig b/src/db/models.zig index a0eb4a3..4966c2d 100644 --- a/src/db/models.zig +++ b/src/db/models.zig @@ -1,5 +1,31 @@ const std = @import("std"); -const utils = @import("../utils.zig"); + +// fn UniqueSQLType(comptime ValType: type) type { +// comptime { +// _ = switch (@typeInfo(ValType)) { +// .Int, .Float, .Pointer => true, +// else => { +// @compileError("An invalid type passed to unique sql contraint type: " ++ @typeName(ValType)); +// }, +// }; +// } +// comptime { +// const t = switch (@typeInfo(ValType)) { +// .Int => @Type(.{ .Int = .{ .signedness = @typeInfo(ValType).Int.signedness, .bits = @typeInfo(ValType).Int.bits } }), +// .Float => @Type(.{ .Float = .{ .bits = @typeInfo(ValType).Float.bits } }), +// .Pointer => { +// const pt = @typeInfo(ValType).Pointer; +// // _ = std.fmt.comptimePrint("Pointer type info: {any}", .{pt}); +// return @Type(.{ .Pointer = .{ .size = pt.size, .is_const = pt.is_const, .is_volatile = pt.is_volatile, .alignment = pt.alignment, .address_space = pt.address_space, .child = pt.child, .is_allowzero = pt.is_allowzero, .sentinel = pt.sentinel } }); +// }, +// else => unreachable, +// }; +// return t; +// } +// // @setEvalBranchQuota(additional_fields.len * fields.len * 10); +// } + +// pub const UniqueType = union(enum) { string: []const u8, real: f64 }; pub const Transaction = struct { id: u32, @@ -7,7 +33,7 @@ pub const Transaction = struct { type: []const u8, memo: ?[]const u8, budget_id: u32, - added_by_user_id: u32, + created_by_user_id: u32, budget_category_id: ?u32, date: u64, created_at: u64, @@ -28,6 +54,7 @@ pub const BudgetCategory = struct { pub const Budget = struct { id: u32, + family_id: u32, name: []const u8, created_at: u64, updated_at: u64, @@ -37,6 +64,9 @@ pub const Budget = struct { pub const User = struct { id: u32, name: []const u8, + username: ?[]const u8, + email: ?[]const u8, + pass_hash: u32, family_id: u32, budget_id: u32, created_at: u64, @@ -47,13 +77,45 @@ pub const User = struct { pub const Family = struct { id: u32, - budget_id: u32, - hide: u8, + code: ?[]const u8, created_at: u64, updated_at: u64, + hide: u8, }; -pub const ModelTypes = [5]type{ Transaction, BudgetCategory, Budget, User, Family }; +pub const SharedNote = struct { + id: u32, + family_id: u32, + created_by_user_id: u32, + content: []const u8, + title: []const u8, + color: ?[]const u8, + tag_ids: []const u8, + is_markdown: u2, + created_at: u64, + updated_at: u64, + hide: u8, +}; + +pub const Tag = struct { + id: u32, + family_id: u32, + created_by_user_id: u32, + name: []const u8, + type: []const u8, + created_at: u64, + updated_at: u64, + hide: u8, +}; + +pub const Token = struct { + user_id: u32, + family_id: u32, + generated_at: u64, + expires_at: u64, +}; + +pub const ModelTypes = [_]type{ Transaction, BudgetCategory, Budget, User, Family, SharedNote, Tag }; /// Functions for creating SQLite queries for any models above pub inline fn createSelectOnIdQuery(comptime Type: type) []const u8 { @@ -67,7 +129,11 @@ pub inline fn createSelectOnFieldQuery( comptime comparator: []const u8, ) ![]const u8 { comptime { - try std.testing.expect(fieldName == null and structField != null or fieldName != null and structField == null); + if (fieldName == null and structField == null) { + @compileError("Must provide struct and fieldname"); + } else if (fieldName != null and structField != null) { + @compileError("Cannot provide struct and fieldname"); + } var field: []const u8 = undefined; if (structField != null) { field = structField.?; @@ -79,6 +145,19 @@ pub inline fn createSelectOnFieldQuery( } } +pub inline fn createRawSelectQuery(comptime Type: type, comptime where_query: []const u8) ![]const u8 { + comptime { + if (!std.mem.containsAtLeast(u8, where_query, 1, "WHERE")) { + @compileError("Provided where query does not contain 'WHERE' string: " ++ where_query); + } + if (!std.mem.endsWith(u8, where_query, ";")) { + @compileError("Where query does not end with semicolon: " ++ where_query); + } + var query = "SELECT * FROM " ++ getTypeTableName(Type) ++ " " ++ where_query; + return query; + } +} + pub inline fn createDeleteOnIdQuery(comptime Type: type) []const u8 { return "DELETE from " ++ getTypeTableName(Type) ++ " WHERE id = ?;"; } @@ -89,13 +168,14 @@ pub inline fn createInsertQuery(comptime Type: type) []const u8 { var qs: []const u8 = "?"; inline for (@typeInfo(Type).Struct.fields, 0..) |field, i| { // This is brittle, assumes 'id' struct field is first + if (std.mem.eql(u8, field.name, "id")) { + continue; + } if (i > 1) { query = query ++ ", "; qs = qs ++ ", ?"; } - if (i != 0) { - query = query ++ field.name; - } + query = query ++ field.name; } query = query ++ ") VALUES (" ++ qs ++ ");"; return query; @@ -139,6 +219,11 @@ pub inline fn createTableMigrationQuery(comptime Type: type) []const u8 { inline fn getSQLiteColumnMigrateText(comptime struct_field: std.builtin.Type.StructField) []const u8 { comptime { if (std.mem.eql(u8, struct_field.name, "id")) return "INTEGER PRIMARY KEY AUTOINCREMENT"; + if (std.mem.eql(u8, @typeName(struct_field.type), "UniqueSQLType")) { + @compileLog(struct_field.type); + } + + // _ = std.fmt.comptimePrint("Got type {any}", .{@typeInfo(struct_field.type)}); const val = switch (@typeInfo(struct_field.type)) { .Int => "INTEGER NOT NULL", .Float => "REAL NOT NULL", @@ -151,7 +236,14 @@ inline fn getSQLiteColumnMigrateText(comptime struct_field: std.builtin.Type.Str }, .Array => "TEXT NOT NULL", .Pointer => "TEXT NOT NULL", - else => unreachable, + + // .Struct => { + // if (std.mem.eql(u8, @typeName(struct_field.type), "UniqueSQLString") or std.mem.eql(u8, @typeName(struct_field.type), "UniqueSQLString")) + // return "THIS"; + // }, + else => { + @compileError("Passed in a type that has no sql migration defined: " ++ @typeName(struct_field.type)); + }, }; return val; } @@ -165,7 +257,14 @@ pub inline fn getTypeTableName(comptime Type: type) []const u8 { Budget => "budgets", BudgetCategory => "budget_categories", Family => "families", - else => unreachable, + SharedNote => "shared_notes", + Tag => "tags", + // Tag => "tags", + else => { + @compileError("Missing a table name, check to make sure that each model in the models list has a table name provided here.\n" ++ @typeName(Type)); + // _ = std.fmt.comptimePrint("Missing a table name, check to make sure that each model in the models list has a table name provided here.\n{any}", .{ModelTypes}); + // unreachable; + }, }; } } diff --git a/src/http_handler.zig b/src/http_handler.zig index 6bb1e89..16bb12d 100644 --- a/src/http_handler.zig +++ b/src/http_handler.zig @@ -6,9 +6,11 @@ const ztime = @import(".deps/time.zig"); const utils = @import("utils.zig"); const budget = @import("routes/budget.zig"); +const auth = @import("routes/auth.zig"); const user = @import("routes/user.zig"); const trans = @import("routes/transactions.zig"); const dash = @import("routes/dashboard.zig"); +const note = @import("routes/shared_note.zig"); const Db = @import("db/db.zig").Db; @@ -35,22 +37,28 @@ pub fn startHttpServer() !void { var router = server.router(); - router.get("/user/:id", user.getUser); - router.put("/user", user.putUser); - router.delete("/user/:id", user.deleteUser); + router.post("/auth/login", user.login); + router.post("/auth/signup", user.signup); - router.get("/budget/:id", budget.getBudget); - router.put("/budget", budget.putBudget); - router.post("/budget", budget.postBudget); + // router.get("/user/:id", user.getUser); + router.put("/user", user.putUser); + // router.delete("/user/:id", user.deleteUser); + + router.get("/shared_notes/:limit", note.getSharedNotes); + router.put("/shared_notes", note.putSharedNote); + router.post("/shared_notes", note.postSharedNote); + + // router.get("/budget/:id", budget.getBudget); + // router.put("/budget", budget.putBudget); + // router.post("/budget", budget.postBudget); router.put("/budget_category", budget.putBudgetCategory); router.post("/budget_category", budget.postBudgetCategory); - router.get("/transactions/:budget_id", trans.getTransactions); router.post("/transactions", trans.postTransaction); router.put("/transactions", trans.putTransaction); - router.get("/dashboard/:family_id", dash.getDashboard); + router.get("/dashboard", dash.getDashboard); std.debug.print("Starting http server listening on port {}\n", .{8081}); // start the server in the current thread, blocking. @@ -75,8 +83,8 @@ fn errorHandler(req: *httpz.Request, res: *httpz.Response, err: anyerror) void { pub fn returnError(message: ?[]const u8, comptime statusCode: u16, res: *httpz.Response) void { comptime { - if (statusCode < 300 or statusCode > 500) { - @compileError("Failed responses must have status codes between 300 and 500"); + if (statusCode > 500 or statusCode < 200) { + @compileError("Failed responses must have status codes between 200 and 500"); } } res.status = statusCode; diff --git a/src/routes/auth.zig b/src/routes/auth.zig index e69de29..b88c5fd 100644 --- a/src/routes/auth.zig +++ b/src/routes/auth.zig @@ -0,0 +1,90 @@ +const std = @import("std"); +const jwt = @import("../.deps/jwt.zig"); +const httpz = @import("../.deps/http.zig/src/httpz.zig"); +const models = @import("../db/models.zig"); +const ztime = @import("../.deps/time.zig"); +const handler = @import("../http_handler.zig"); +const utils = @import("../utils.zig"); + +//TODO move these to env variables + +const key = "aGVyZUlzQUdpYmVyaXNoS2V5ISE="; +pub const HASH_SEED: u64 = 6065983110; + +pub const VerifyAuthError = error{ + Unauthorized, + NotAuthenticated, + BadToken, + Expired, +}; + +pub fn verifyRequest(req: *httpz.Request, res: *httpz.Response, user_id: ?u32, family_id: ?u32) !models.Token { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + const coded_token = req.headers.get("token"); + const now = + @bitCast(u64, std.time.milliTimestamp()); + const date = ztime.DateTime.now(); + const formatted_now = date.formatAlloc(allocator, "DD.MM.YYYY HH:mm:ss") catch "N/A"; + + const method = @tagName(req.method); + + if (coded_token == null) { + handler.returnError("Unauthorized/NoToken", 401, res); + std.log.info("{s} {s} Unauthorized/NotAuthenticated - @ {s}\n", .{ method, req.url.query, formatted_now }); + return VerifyAuthError.NotAuthenticated; + } + + const token = jwt.validate(models.Token, allocator, .HS256, coded_token.?, .{ .key = key }) catch { + handler.returnError("Unauthorized", 401, res); + std.log.info("{s} {s} Unauthorized/BadToken - Token: {s} @ {s}\n", .{ method, req.url.raw, coded_token.?, formatted_now }); + return VerifyAuthError.BadToken; + }; + + if (user_id != null and user_id.? != token.user_id or family_id != null and family_id.? != token.family_id) { + handler.returnError("Unauthorized", 401, res); + std.log.info("{s} {s} Unauthorized - User: {} Family: {any} @ {s}\n", .{ method, req.url.raw, token.user_id, token.family_id, formatted_now }); + return VerifyAuthError.Unauthorized; + } + + if (token.expires_at < now) { + std.log.info("{s} {s} Unauthorized/Expired - User: {} Family: {any} @ {s}\n", .{ method, req.url.raw, token.user_id, token.family_id, formatted_now }); + handler.returnError("Credentials Expired", 403, res); + return VerifyAuthError.Expired; + } + + std.log.info("{s} {s} Authorized - User: {} Family: {any} @ {s}\n", .{ method, req.url.raw, token.user_id, token.family_id, formatted_now }); + return token; +} + +pub fn generateToken(user: models.User) ![]const u8 { + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + const now = @bitCast(u64, std.time.milliTimestamp()); + var seven_days = ztime.DateTime.now(); + seven_days = seven_days.addDays(7); + const token: models.Token = .{ .user_id = user.id, .family_id = user.family_id, .generated_at = now, .expires_at = seven_days.toUnixMilli() }; + const encoded_token = try jwt.encode(allocator, .HS256, token, .{ .key = key }); + std.log.info("Generated token for User {} @ {}", .{ user.id, now }); + return encoded_token; +} + +test { + var gpa = std.testing.allocator_instance; + const allocator = gpa.allocator(); + + const user = models.User{ .id = 1, .family_id = 1, .budget_id = 1, .name = "Billy", .created_at = 1, .updated_at = 1, .last_activity_at = 1, .hide = 0 }; + const user_token = try generateToken(user); + + // std.debug.print("user_token: {s}\n", .{user_token}); + // var key = try std.testing.allocator.alloc(u8, try base64url.Decoder.calcSizeForSlice(key_base64)); + // defer std.testing.allocator.free(key); + // try base64url.Decoder.decode(key, key_base64); + + // std.debug.print("key: {s}\n", .{key}); + + // const user_token: models.Token = .{ .user_id = 1, .family_id = 1, .generated_at = @bitCast(u64, std.time.milliTimestamp()), .expires_at = @bitCast(u64, std.time.milliTimestamp()) }; + // const coded_message = try jwt.encode(allocator, .HS256, user_token, .{ .key = key }); + const decoded_message = try jwt.validate(models.Token, allocator, .HS256, user_token, .{ .key = key }); + std.debug.print("user: {any}\nuser_token: {s}\ndecoded: {any}\n", .{ user, user_token, decoded_message }); +} diff --git a/src/routes/budget.zig b/src/routes/budget.zig index 9e74123..c782d20 100644 --- a/src/routes/budget.zig +++ b/src/routes/budget.zig @@ -3,122 +3,123 @@ const httpz = @import("../.deps/http.zig/src/httpz.zig"); const models = @import("../db/models.zig"); const ztime = @import("../.deps/time.zig"); const utils = @import("../utils.zig"); - +const auth = @import("auth.zig"); const handler = @import("../http_handler.zig"); -pub fn getBudget(req: *httpz.Request, res: *httpz.Response) !void { - const db = handler.getDb(); +// pub fn getBudget(req: *httpz.Request, res: *httpz.Response) !void { +// const db = handler.getDb(); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const allocator = gpa.allocator(); - const id_str = req.param("id"); - if (id_str == null) { - res.status = 400; - res.body = "Bad Request: No Id"; - return; - } - const id = std.fmt.parseInt(u32, id_str.?, 0) catch { - res.status = 401; - res.body = "Bad Request: Bad Id"; - return; - }; +// var gpa = std.heap.GeneralPurposeAllocator(.{}){}; +// const allocator = gpa.allocator(); +// const id_str = req.param("id"); +// if (id_str == null) { +// res.status = 400; +// res.body = "Bad Request: No Id"; +// return; +// } +// const id = std.fmt.parseInt(u32, id_str.?, 0) catch { +// res.status = 401; +// res.body = "Bad Request: Bad Id"; +// return; +// }; - const budget = try db.selectOneById(models.Budget, allocator, id); +// const budget = try db.selectOneById(models.Budget, allocator, id); - if (budget == null) { - res.status = 404; - res.body = "Budget not found"; - return; - } +// if (budget == null) { +// res.status = 404; +// res.body = "Budget not found"; +// return; +// } - try res.json(budget.?, .{}); -} +// try res.json(budget.?, .{}); +// } -const BudgetPostReq = struct { - id: ?u32, - name: []const u8, - created_at: ?u64, - updated_at: ?u64, - hide: u8, -}; +// const BudgetPostReq = struct { +// id: ?u32, +// family_id: u32, +// name: []const u8, +// created_at: ?u64, +// updated_at: ?u64, +// hide: u8, +// }; -pub fn putBudget(req: *httpz.Request, res: *httpz.Response) !void { - var db = handler.getDb(); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const allocator = gpa.allocator(); +// pub fn putBudget(req: *httpz.Request, res: *httpz.Response) !void { +// var db = handler.getDb(); +// var gpa = std.heap.GeneralPurposeAllocator(.{}){}; +// const allocator = gpa.allocator(); - const body_data = req.json(models.Budget) catch |err| { - std.debug.print("Malformed body: {any}\n", .{err}); - handler.returnError("Bad Request: Malformed Body", 400, res); - return; - }; - if (body_data == null) { - handler.returnError("Bad Request: No Data", 400, res); - return; - } - var body = body_data.?; +// const body_data = req.json(models.Budget) catch |err| { +// std.debug.print("Malformed body: {any}\n", .{err}); +// handler.returnError("Bad Request: Malformed Body", 400, res); +// return; +// }; +// if (body_data == null) { +// handler.returnError("Bad Request: No Data", 400, res); +// return; +// } +// var body = body_data.?; - // Add Budget - const now = @intCast(u64, std.time.milliTimestamp()); - // Update existing Budget - body.updated_at = now; - try db.updateById(models.Budget, body); +// // Add Budget +// const now = @intCast(u64, std.time.milliTimestamp()); +// // Update existing Budget +// body.updated_at = now; +// try db.updateById(models.Budget, body); - const query = models.createSelectOnIdQuery(models.Transaction); - const updated_budget = try db.selectOne(models.Budget, allocator, query, .{ .id = body.id }); - if (updated_budget) |budget| { - try handler.returnData(budget, res); - } else { - handler.returnError("Internal Server Error", 500, res); - } - return; -} +// const query = models.createSelectOnIdQuery(models.Transaction); +// const updated_budget = try db.selectOne(models.Budget, allocator, query, .{ .id = body.id }); +// if (updated_budget) |budget| { +// try handler.returnData(budget, res); +// } else { +// handler.returnError("Internal Server Error", 500, res); +// } +// return; +// } -pub fn postBudget(req: *httpz.Request, res: *httpz.Response) !void { - comptime { - const putReqLen = @typeInfo(BudgetPostReq).Struct.fields.len; - const budgetLen = @typeInfo(models.Budget).Struct.fields.len; - if (putReqLen != budgetLen) { - @compileError(std.fmt.comptimePrint("BudgetPutReq does not equal Budget model struct, fields inconsistent", .{})); - } - } +// pub fn postBudget(req: *httpz.Request, res: *httpz.Response) !void { +// comptime { +// const putReqLen = @typeInfo(BudgetPostReq).Struct.fields.len; +// const budgetLen = @typeInfo(models.Budget).Struct.fields.len; +// if (putReqLen != budgetLen) { +// @compileError(std.fmt.comptimePrint("BudgetPostReq does not equal Budget model struct, fields inconsistent", .{})); +// } +// } - var db = handler.getDb(); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const allocator = gpa.allocator(); +// var db = handler.getDb(); +// var gpa = std.heap.GeneralPurposeAllocator(.{}){}; +// const allocator = gpa.allocator(); - const body_data = req.json(BudgetPostReq) catch |err| { - std.debug.print("Malformed body: {any}\n", .{err}); - handler.returnError("Bad request: Malformed Body", 400, res); - return; - }; - if (body_data == null) { - handler.returnError("Bad request: No Data", 400, res); - return; - } - var body = body_data.?; +// const body_data = req.json(BudgetPostReq) catch |err| { +// std.debug.print("Malformed body: {any}\n", .{err}); +// handler.returnError("Bad request: Malformed Body", 400, res); +// return; +// }; +// if (body_data == null) { +// handler.returnError("Bad request: No Data", 400, res); +// return; +// } +// var body = body_data.?; - if (body.id != null) { - handler.returnError("Bad Request: ID", 400, res); - } - // Add Budget - const now = @intCast(u64, std.time.milliTimestamp()); - // Create Budget - body.created_at = now; - body.updated_at = now; +// if (body.id != null) { +// handler.returnError("Bad Request: ID", 400, res); +// } +// // Add Budget +// const now = @intCast(u64, std.time.milliTimestamp()); +// // Create Budget +// body.created_at = now; +// body.updated_at = now; - try db.insert(models.Budget, utils.removeStructFields(body, &[_]u8{0})); +// try db.insert(models.Budget, utils.removeStructFields(body, &[_]u8{0})); - // Get Budget - const query = try models.createSelectOnFieldQuery(models.Budget, null, "created_at", "="); - const updated_budget = try db.selectOne(models.Budget, allocator, query, .{ .created_at = body.created_at }); - if (updated_budget) |budget| { - try handler.returnData(budget, res); - } else { - handler.returnError("Internal Server Error", 500, res); - } - return; -} +// // Get Budget +// const query = try models.createSelectOnFieldQuery(models.Budget, null, "created_at", "="); +// const updated_budget = try db.selectOne(models.Budget, allocator, query, .{ .created_at = body.created_at }); +// if (updated_budget) |budget| { +// try handler.returnData(budget, res); +// } else { +// handler.returnError("Internal Server Error", 500, res); +// } +// return; +// } const BudgetCatPostReq = struct { id: ?u32, @@ -160,6 +161,16 @@ pub fn postBudgetCategory(req: *httpz.Request, res: *httpz.Response) !void { handler.returnError("Bad request: ID", 400, res); return; } + + const budget = try db.selectOneById(models.Budget, allocator, body.budget_id); + if (budget == null) { + handler.returnError("No budget found", 404, res); + } + + _ = auth.verifyRequest(req, res, null, budget.?.family_id) catch { + return; + }; + // Add Budget const now = @intCast(u64, std.time.milliTimestamp()); // Create Budget @@ -170,14 +181,14 @@ pub fn postBudgetCategory(req: *httpz.Request, res: *httpz.Response) !void { // Get Budget const query = try models.createSelectOnFieldQuery(models.BudgetCategory, null, "created_at", "="); - const updated_budget = try db.selectOne(models.BudgetCategory, allocator, query, .{ .created_at = body.created_at }); - if (updated_budget) |budget| { - try handler.returnData(budget, res); - } else { + const updated_budget_cat = try db.selectOne(models.BudgetCategory, allocator, query, .{ .created_at = body.created_at }); + + if (updated_budget_cat == null) { std.debug.print("Could not find inserted budget", .{}); handler.returnError("Internal Server Error", 500, res); + return; } - return; + try handler.returnData(updated_budget_cat.?, res); } pub fn putBudgetCategory(req: *httpz.Request, res: *httpz.Response) !void { @@ -196,6 +207,15 @@ pub fn putBudgetCategory(req: *httpz.Request, res: *httpz.Response) !void { } var budget_category = body_data.?; + const budget = try db.selectOneById(models.Budget, allocator, budget_category.budget_id); + if (budget == null) { + handler.returnError("No budget found", 404, res); + } + + _ = auth.verifyRequest(req, res, null, budget.?.family_id) catch { + return; + }; + const now = @intCast(u64, std.time.milliTimestamp()); // Update existing Budget @@ -203,12 +223,12 @@ pub fn putBudgetCategory(req: *httpz.Request, res: *httpz.Response) !void { try db.updateById(models.BudgetCategory, budget_category); const query = models.createSelectOnIdQuery(models.BudgetCategory); - const updated_budget = try db.selectOne(models.BudgetCategory, allocator, query, .{ .id = budget_category.id }); - if (updated_budget) |budget| { - try handler.returnData(budget, res); - } else { + const updated_budget_cat = try db.selectOne(models.BudgetCategory, allocator, query, .{ .id = budget_category.id }); + if (updated_budget_cat == null) { std.debug.print("Could not find inserted budget", .{}); handler.returnError("Internal Server Error", 500, res); + return; } + try handler.returnData(updated_budget_cat.?, res); return; } diff --git a/src/routes/dashboard.zig b/src/routes/dashboard.zig index 932b895..248794d 100644 --- a/src/routes/dashboard.zig +++ b/src/routes/dashboard.zig @@ -4,7 +4,9 @@ const models = @import("../db/models.zig"); const ztime = @import("../.deps/time.zig"); const utils = @import("../utils.zig"); const trans = @import("transactions.zig"); +const note = @import("shared_note.zig"); const handler = @import("../http_handler.zig"); +const auth = @import("auth.zig"); pub fn getDashboard(req: *httpz.Request, res: *httpz.Response) !void { const db = handler.getDb(); @@ -12,41 +14,62 @@ pub fn getDashboard(req: *httpz.Request, res: *httpz.Response) !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; const allocator = gpa.allocator(); - const family_id_str = req.param("family_id"); - if (family_id_str == null) { - res.status = 400; - res.body = "Bad Request: No FamilyId"; - return; - } - const family_id = std.fmt.parseInt(u32, family_id_str.?, 0) catch { - res.status = 400; - res.body = "Bad Request: Bad FamilyId"; + const token = auth.verifyRequest(req, res, null, null) catch { return; }; - const family = db.selectOneById(models.Family, allocator, family_id) catch |err| { - if (err == error.SQLiteError) { - res.status = 404; - res.body = "Family Not Found"; - return; - } - std.debug.print("Error while getting family: {}\n", .{err}); - res.status = 500; - res.body = "Internal Server Error"; - return; + // const family_id = std.fmt.parseInt(u32, family_id_str.?, 0) catch { + // res.status = 400; + // res.body = "Bad Request: Bad FamilyId"; + // return; + // }; + + const family = db.selectOneById(models.Family, allocator, token.family_id) catch |err| { + std.log.err("Error getting family in dashboard: {any}", .{err}); + return handler.returnError("Unexpected Server Error", 500, res); }; + if (family == null) { - res.status = 404; - res.body = "Family Not Found"; - return; + std.log.err("Family not found, invalidating client token", .{}); + return handler.returnError("Family does not exist or forbidden", 403, res); } - const transactions = try trans.fetchTransFromDb(allocator, family.?.budget_id); + const user = db.selectOneById(models.User, allocator, token.user_id) catch |err| { + std.log.err("User not found, invalidating client token, Err: {any}", .{err}); + return handler.returnError("Could not get user or forbidden", 403, res); + }; + + if (user == null) { + std.log.err("User not found, invalidating client token", .{}); + return handler.returnError("User not found or forbidden", 403, res); + } + + const transactions = trans.fetchTransFromDb(allocator, family.?.id) catch |err| { + std.log.err("Unexpected error while getting transactions: {any}", .{err}); + handler.returnError("Internal Server Error", 500, res); + return; + }; + + const notes = note.fetchNotesFromDb(allocator, family.?.id) catch |err| { + std.log.err("Unexpected error while getting transactions: {any}", .{err}); + handler.returnError("Internal Server Error", 500, res); + return; + }; + + const budget_query = try models.createSelectOnFieldQuery(models.Budget, null, "family_id", "="); + const budget = db.selectOne(models.Budget, allocator, budget_query, .{ .family_id = token.family_id }) catch |err| { + std.log.err("Unexpected error while getting budget: {any}", .{err}); + handler.returnError("Internal Server Error", 500, res); + return; + }; - const budget = try db.selectOneById(models.Budget, allocator, family.?.budget_id); var budget_categories: ?[]models.BudgetCategory = null; if (budget != null) { - budget_categories = try db.selectAllWhere(models.BudgetCategory, allocator, "WHERE budget_id = ? AND hide = ?", .{ .budget_id = budget.?.id, .hide = 0 }, null); + budget_categories = db.selectAllWhere(models.BudgetCategory, allocator, "WHERE budget_id = ? AND hide = ?", .{ .budget_id = budget.?.id, .hide = 0 }, null, null, null, null) catch |err| { + std.log.err("Unexpected error while getting budget categories: {any}", .{err}); + handler.returnError("Internal Server Error", 500, res); + return; + }; } if (budget_categories == null) { @@ -54,9 +77,12 @@ pub fn getDashboard(req: *httpz.Request, res: *httpz.Response) !void { } const response_body = .{ .family = family.?, + .user = user.?, .budget = budget, .budget_categories = budget_categories, .transactions = transactions, + .shared_notes = notes, + .success = true, }; try res.json(response_body, .{}); } diff --git a/src/routes/shared_note.zig b/src/routes/shared_note.zig new file mode 100644 index 0000000..611f180 --- /dev/null +++ b/src/routes/shared_note.zig @@ -0,0 +1,155 @@ +const std = @import("std"); +const httpz = @import("../.deps/http.zig/src/httpz.zig"); +const models = @import("../db/models.zig"); +// const ztime = @import("../.deps/time.zig"); +const utils = @import("../utils.zig"); + +const auth = @import("auth.zig"); +const handler = @import("../http_handler.zig"); + +pub fn fetchNotesFromDb(allocator: std.mem.Allocator, family_id: u32) !?[]models.SharedNote { + var db = handler.getDb(); + + const notes = db.selectAllWhere( + models.SharedNote, + allocator, + "WHERE family_id = ? and hide = ?", + .{ .family_id = family_id, .hide = 0 }, + "updated_at", + "DESC", + true, + 10, + ) catch |err| { + std.debug.print("Error while getting shared notes: {any}", .{err}); + return err; + }; + return notes; +} + +pub fn getSharedNotes(req: *httpz.Request, res: *httpz.Response) !void { + var db = handler.getDb(); + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + var str_limit: ?[]const u8 = req.param("limit"); + if (str_limit == null) { + str_limit = "10"; + } + const limit = std.fmt.parseInt(u32, str_limit.?, 0) catch { + handler.returnError("Bad Request: Bad Limit", 401, res); + return; + }; + + const token = auth.verifyRequest(req, res, null, null) catch { + return; + }; + + const notes = db.selectAllWhere( + models.SharedNote, + allocator, + "WHERE family_id = ? and hide = ?", + .{ .family_id = token.family_id, .hide = 0 }, + "updated_at", + "DESC", + true, + limit, + ) catch |err| { + std.debug.print("Error while getting shared notes: {any}", .{err}); + return err; + }; + + try res.json(.{ .notes = notes }, .{}); + return; +} + +pub fn putSharedNote(req: *httpz.Request, res: *httpz.Response) !void { + var db = handler.getDb(); + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + const body_data = req.json(models.SharedNote) catch |err| { + std.debug.print("Malformed body: {any}\n", .{err}); + handler.returnError("Bad request: Malformed Body", 400, res); + return; + }; + if (body_data == null) { + handler.returnError("Bad request: No Data", 400, res); + return; + } + var shared_note = body_data.?; + + _ = auth.verifyRequest(req, res, shared_note.created_by_user_id, shared_note.family_id) catch { + return; + }; + + const now = @intCast(u64, std.time.milliTimestamp()); + shared_note.updated_at = now; + try db.updateById(models.SharedNote, shared_note); + + const updated_note = try db.selectOneById(models.SharedNote, allocator, shared_note.id); + if (updated_note) |note| { + try handler.returnData(note, res); + } else { + handler.returnError("Internal Server Error", 500, res); + } + return; +} + +const NotePostReq = struct { + id: ?u32, + family_id: u32, + created_by_user_id: u32, + content: []const u8, + title: []const u8, + color: ?[]const u8, + tag_ids: []const u8, + is_markdown: u2, + created_at: ?u64, + updated_at: ?u64, + hide: u8, +}; + +pub fn postSharedNote(req: *httpz.Request, res: *httpz.Response) !void { + comptime { + const postReqLen = @typeInfo(NotePostReq).Struct.fields.len; + const noteLen = @typeInfo(models.SharedNote).Struct.fields.len; + if (postReqLen != noteLen) { + @compileError(std.fmt.comptimePrint("SharedNotePutReq does not equal SharedNote model struct, fields inconsistent", .{})); + } + } + var db = handler.getDb(); + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + const body_data = req.json(NotePostReq) catch |err| { + std.debug.print("Malformed body: {any}\n", .{err}); + handler.returnError("Bad request: Malformed Body", 400, res); + return; + }; + if (body_data == null) { + handler.returnError("Bad request: No Data", 400, res); + return; + } + var body = body_data.?; + + _ = auth.verifyRequest(req, res, body.created_by_user_id, body.family_id) catch { + return; + }; + + const now = @intCast(u64, std.time.milliTimestamp()); + body.created_at = now; + body.updated_at = now; + + // remove the null id field for insertion + try db.insert(models.SharedNote, utils.removeStructFields(body, &[_]u8{0})); + + // Get new SharedNote + const query = try models.createSelectOnFieldQuery(models.SharedNote, null, "created_at", "="); + const updated_note = try db.selectOne(models.SharedNote, allocator, query, .{ .created_at = body.created_at }); + if (updated_note) |note| { + try handler.returnData(note, res); + } else { + handler.returnError("Internal Server Error", 500, res); + } + return; +} diff --git a/src/routes/transactions.zig b/src/routes/transactions.zig index 33db3bd..f657866 100644 --- a/src/routes/transactions.zig +++ b/src/routes/transactions.zig @@ -4,9 +4,10 @@ const models = @import("../db/models.zig"); const ztime = @import("../.deps/time.zig"); const utils = @import("../utils.zig"); +const auth = @import("auth.zig"); const handler = @import("../http_handler.zig"); -pub fn fetchTransFromDb(allocator: std.mem.Allocator, budget_id: u32) !?[]models.Transaction { +pub fn fetchTransFromDb(allocator: std.mem.Allocator, family_id: u32) !?[]models.Transaction { var db = handler.getDb(); const now = ztime.DateTime.now(); const beginningOfMonth = ztime.DateTime.init(now.years, now.months, 0, 0, 0, 0); @@ -16,43 +17,61 @@ pub fn fetchTransFromDb(allocator: std.mem.Allocator, budget_id: u32) !?[]models return error{TransactionModelError}; } } - const transactions = try db.selectAllWhere( + const budget_query = models.createSelectOnFieldQuery(models.Budget, null, "family_id", "=") catch |err| { + std.debug.print("Error while creating budget query: {any}", .{err}); + return err; + }; + const budget = db.selectOne(models.Budget, allocator, budget_query, .{ .family_id = family_id }) catch |err| { + std.debug.print("Error while getting budget: {any}", .{err}); + return err; + }; + if (budget == null) { + return null; + } + const transactions = db.selectAllWhere( models.Transaction, allocator, "WHERE budget_id = ? AND date > ? AND hide = ?", - .{ .budget_id = budget_id, .date = beginningOfMonth.toUnixMilli(), .hide = 0 }, + .{ .budget_id = budget.?.id, .date = beginningOfMonth.toUnixMilli(), .hide = 0 }, null, - ); + null, + null, + null, + ) catch |err| { + std.debug.print("Error while getting transactions query: {any}", .{err}); + return err; + }; return transactions; } -pub fn getTransactions(req: *httpz.Request, res: *httpz.Response) !void { - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const allocator = gpa.allocator(); +// pub fn getTransactions(req: *httpz.Request, res: *httpz.Response) !void { +// auth.verifyRequest(); +// var gpa = std.heap.GeneralPurposeAllocator(.{}){}; +// const allocator = gpa.allocator(); - const budget_id_str = req.param("budget_id"); - if (budget_id_str == null) { - handler.returnError("Bad request", 400, res); - return; - } - const budget_id = std.fmt.parseInt(u32, budget_id_str.?, 0) catch { - handler.returnError("Bad request", 400, res); - return; - }; +// const budget_id_str = req.param("budget_id"); +// if (budget_id_str == null) { +// handler.returnError("Bad request", 400, res); +// return; +// } +// const budget_id = std.fmt.parseInt(u32, budget_id_str.?, 0) catch { +// handler.returnError("Bad request", 400, res); +// return; +// }; - const transactions = try fetchTransFromDb(allocator, budget_id); +// const transactions = try fetchTransFromDb(allocator, budget_id); - if (transactions == null) { - res.status = 200; - res.body = ""; - return; - } - // std.debug.print("Transactions got:\n", .{}); - // for (transactions.?) |transaction| { - // std.debug.print("\t{any}\n", .{transaction}); - // } - try handler.returnData(.{ .transactions = transactions.? }, res); -} +// if (transactions == null) { +// res.status = 200; +// res.body = ""; +// return; +// } +// // std.debug.print("Transactions got:\n", .{}); +// // for (transactions.?) |transaction| { +// // std.debug.print("\t{any}\n", .{transaction}); +// // } +// try handler.returnData(.{ .transactions = transactions.? }, res); +// } pub fn putTransaction(req: *httpz.Request, res: *httpz.Response) !void { var db = handler.getDb(); @@ -70,9 +89,16 @@ pub fn putTransaction(req: *httpz.Request, res: *httpz.Response) !void { } var transaction = body_data.?; - // Add Transaction + const budget = try db.selectOneById(models.Budget, allocator, transaction.budget_id); + if (budget == null) { + return handler.returnError("Budget for transaction not found or forbidden", 403, res); + } + + _ = auth.verifyRequest(req, res, transaction.created_by_user_id, budget.?.family_id) catch { + return; + }; + const now = @intCast(u64, std.time.milliTimestamp()); - // Update existing Transaction transaction.updated_at = now; try db.updateById(models.Transaction, transaction); @@ -92,7 +118,7 @@ const TransPostReq = struct { type: []const u8, memo: ?[]const u8, budget_id: u32, - added_by_user_id: u32, + created_by_user_id: u32, budget_category_id: ?u32, date: u64, created_at: ?u64, @@ -102,9 +128,9 @@ const TransPostReq = struct { pub fn postTransaction(req: *httpz.Request, res: *httpz.Response) !void { comptime { - const putReqLen = @typeInfo(TransPostReq).Struct.fields.len; + const postReqLen = @typeInfo(TransPostReq).Struct.fields.len; const transLen = @typeInfo(models.Transaction).Struct.fields.len; - if (putReqLen != transLen) { + if (postReqLen != transLen) { @compileError(std.fmt.comptimePrint("TransactionPutReq does not equal Transaction model struct, fields inconsistent", .{})); } } @@ -123,15 +149,19 @@ pub fn postTransaction(req: *httpz.Request, res: *httpz.Response) !void { } var body = body_data.?; - if (body.id != null) { - handler.returnError("Bad request: ID", 400, res); - return; + const budget = try db.selectOneById(models.Budget, allocator, body.budget_id); + if (budget == null) { + return handler.returnError("Budget for transaction not foud or forbidden", 403, res); } + _ = auth.verifyRequest(req, res, body.created_by_user_id, budget.?.family_id) catch { + return; + }; const now = @intCast(u64, std.time.milliTimestamp()); body.created_at = now; body.updated_at = now; + // remove the null id field for insertion try db.insert(models.Transaction, utils.removeStructFields(body, &[_]u8{0})); // Get new Transaction diff --git a/src/routes/user.zig b/src/routes/user.zig index e928d70..cd4530d 100644 --- a/src/routes/user.zig +++ b/src/routes/user.zig @@ -2,9 +2,202 @@ const std = @import("std"); const httpz = @import("../.deps/http.zig/src/httpz.zig"); const models = @import("../db/models.zig"); const utils = @import("../utils.zig"); - +const auth = @import("auth.zig"); const handler = @import("../http_handler.zig"); +const LoginReq = struct { + username: ?[]const u8, + email: ?[]const u8, + password: []const u8, +}; + +// POST endpoint for user login requests +pub fn login(req: *httpz.Request, res: *httpz.Response) !void { + const db = handler.getDb(); + + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + const body_data = req.json(LoginReq) catch |err| { + std.debug.print("Malformed body: {any}\n", .{err}); + handler.returnError("Bad Request: Malformed Body", 400, res); + return; + }; + + if (body_data == null) { + handler.returnError("Bad Request: No Data", 400, res); + return; + } + + var body = body_data.?; + if (body.username == null and body.email == null) { + handler.returnError("Bad Request: Missing username / email", 400, res); + return; + } else if (body.username != null and body.email != null) { + handler.returnError("Bad Request: Provided Username and Email", 400, res); + return; + } + + var user: ?models.User = null; + const password_hash = @truncate(u32, std.hash.Wyhash.hash(auth.HASH_SEED, body.password)); + if (body.username != null) { + const query = + "WHERE pass_hash = ? and username = ?;"; + user = try db.selectOne(models.User, allocator, try models.createRawSelectQuery( + models.User, + query, + ), .{ .pass_hash = password_hash, .username = body.username.? }); + } else { + const query = + "WHERE pass_hash = ? and email = ?;"; + user = try db.selectOne(models.User, allocator, try models.createRawSelectQuery( + models.User, + query, + ), .{ .pass_hash = password_hash, .email = body.email.? }); + } + + if (user == null) { + handler.returnError("User not found or incorrect password", 200, res); + return; + } else if (user.?.hide == 1) { + handler.returnError("Account has been closed", 200, res); + return; + } + const token = try auth.generateToken(user.?); + + try handler.returnData(.{ .token = token }, res); +} + +const SignupReq = struct { + name: []const u8, + username: []const u8, + email: ?[]const u8, + password: []const u8, + family_code: ?[]const u8, + budget_name: []const u8, +}; + +// POST endpoint for user signups +pub fn signup(req: *httpz.Request, res: *httpz.Response) !void { + const db = handler.getDb(); + + var gpa = std.heap.GeneralPurposeAllocator(.{}){}; + const allocator = gpa.allocator(); + + const body_data = req.json(SignupReq) catch |err| { + std.debug.print("Malformed body: {any}\n", .{err}); + handler.returnError("Bad Request: Malformed Body", 400, res); + return; + }; + + if (body_data == null) { + handler.returnError("Bad Request: No Data", 400, res); + return; + } + + var body = body_data.?; + // if (body.username == null and body.email == null) { + // handler.returnError("Bad Request: Missing username / email", 400, res); + // return; + // } + + const password_hash = @truncate(u32, std.hash.Wyhash.hash(auth.HASH_SEED, body.password)); + const now = @bitCast(u64, std.time.milliTimestamp()); + + const uname_query = + "WHERE username = ?;"; + var temp_user = try db.selectOne(models.User, allocator, try models.createRawSelectQuery( + models.User, + uname_query, + ), .{ .username = body.username }); + + if (temp_user != null) { + handler.returnError("Username is unavailable", 200, res); + return; + } + + if (body.email != null) { + const email_query = + "WHERE email = ?;"; + temp_user = try db.selectOne(models.User, allocator, try models.createRawSelectQuery( + models.User, + email_query, + ), .{ .email = body.email.? }); + + if (temp_user != null) { + handler.returnError("Email is unavailable", 200, res); + return; + } + } + + var family: ?models.Family = null; + var budget: ?models.Budget = null; + var user: ?models.User = null; + if (body.family_code == null) { + const new_family = .{ + .code = utils.generateRandomString(allocator) catch null, + .created_at = now, + .updated_at = now, + .hide = 0, + }; + try db.insert(models.Family, new_family); + const family_query = try models.createSelectOnFieldQuery(models.Family, null, "code", "="); + family = try db.selectOne(models.Family, allocator, family_query, .{ .code = new_family.code }); + + if (family == null) { + handler.returnError("Unable to create family", 500, res); + return; + } + + const new_budget = .{ + .family_id = family.?.id, + .name = body.budget_name, + .created_at = now, + .updated_at = now, + .hide = 0, + }; + try db.insert(models.Budget, new_budget); + const budget_query = try models.createSelectOnFieldQuery(models.Budget, null, "created_at", "="); + budget = try db.selectOne(models.Budget, allocator, budget_query, .{ .created_at = now }); + + if (budget == null) { + handler.returnError("Unable to create budget", 500, res); + return; + } + } else { + const family_query = try models.createSelectOnFieldQuery(models.Family, null, "code", "="); + family = try db.selectOne(models.Family, allocator, family_query, .{ .code = body.family_code.? }); + + if (family == null) { + return handler.returnError("Invalid Family Code", 404, res); + } + + const budget_query = try models.createSelectOnFieldQuery(models.Budget, null, "family_id", "="); + budget = try db.selectOne(models.Budget, allocator, budget_query, .{ .family_id = family.?.id }); + if (budget == null) { + return handler.returnError("Could not find Family Budget", 404, res); + } + } + + const new_user = .{ .name = body.name, .username = body.username, .email = body.email, .pass_hash = password_hash, .family_id = family.?.id, .budget_id = budget.?.id, .created_at = now, .updated_at = now, .last_activity_at = now, .hide = 0 }; + + try db.insert(models.User, new_user); + + const query = try models.createSelectOnFieldQuery(models.User, null, "pass_hash", "="); + user = try db.selectOne(models.User, allocator, query, .{ .pass_hash = password_hash }); + + if (user == null) { + handler.returnError("Unable to create new user", 500, res); + return; + } + + std.log.info("User created: {any}\nFamily created: {any}\n", .{ user.?, family.? }); + + const token = try auth.generateToken(user.?); + + try handler.returnData(.{ .user = user, .token = token }, res); +} + pub fn getUser(req: *httpz.Request, res: *httpz.Response) !void { const db = handler.getDb(); @@ -23,7 +216,10 @@ pub fn getUser(req: *httpz.Request, res: *httpz.Response) !void { const user = try db.selectOneById(models.User, allocator, id); if (user == null) { - handler.returnError("Error: User Not Found", 404, res); + _ = auth.verifyRequest(req, res, id, user.family_id) catch { + return; + }; + handler.returnError("Error: User Not Found of Forbidden", 403, res); return; } @@ -31,7 +227,6 @@ pub fn getUser(req: *httpz.Request, res: *httpz.Response) !void { } const UserPostReq = struct { - id: ?u32, name: []const u8, family_id: u32, budget_id: u32, @@ -57,6 +252,10 @@ pub fn putUser(req: *httpz.Request, res: *httpz.Response) !void { } var body = body_data.?; + _ = auth.verifyRequest(req, res, body.id, body.family_id) catch { + return; + }; + // Add User const now = @intCast(u64, std.time.milliTimestamp()); // Update existing User @@ -76,74 +275,74 @@ pub fn putUser(req: *httpz.Request, res: *httpz.Response) !void { // try res.json(user, .{}); } -pub fn postUser(req: *httpz.Request, res: *httpz.Response) !void { - comptime { - const putReqLen = @typeInfo(UserPostReq).Struct.fields.len; - const userLen = @typeInfo(models.User).Struct.fields.len; - if (putReqLen != userLen) { - @compileError(std.fmt.comptimePrint("UserPutReq does not equal User model struct, fields inconsistent", .{})); - } - } +// pub fn postUser(req: *httpz.Request, res: *httpz.Response) !void { +// comptime { +// const putReqLen = @typeInfo(UserPostReq).Struct.fields.len; +// const userLen = @typeInfo(models.User).Struct.fields.len; +// if (putReqLen != userLen) { +// @compileError(std.fmt.comptimePrint("UserPutReq does not equal User model struct, fields inconsistent", .{})); +// } +// } - const db = handler.getDb(); - var gpa = std.heap.GeneralPurposeAllocator(.{}){}; - const allocator = gpa.allocator(); +// const db = handler.getDb(); +// var gpa = std.heap.GeneralPurposeAllocator(.{}){}; +// const allocator = gpa.allocator(); - const body_data = req.json(UserPostReq) catch |err| { - std.debug.print("Malformed body: {any}\n", .{err}); - handler.returnError("Bad request: Malformed Body", 400, res); - return; - }; - if (body_data == null) { - handler.returnError("Bad request: No Data", 400, res); - return; - } - var body = body_data.?; +// const body_data = req.json(UserPostReq) catch |err| { +// std.debug.print("Malformed body: {any}\n", .{err}); +// handler.returnError("Bad request: Malformed Body", 400, res); +// return; +// }; +// if (body_data == null) { +// handler.returnError("Bad request: No Data", 400, res); +// return; +// } +// var body = body_data.?; - if (body.id != null) { - handler.returnError("Bad request: ID", 400, res); - return; - } - // Add User - const now = @intCast(u64, std.time.milliTimestamp()); - // Create User - body.created_at = now; - body.last_activity_at = now; - body.updated_at = now; +// if (body.id != null) { +// handler.returnError("Bad request: ID", 400, res); +// return; +// } +// // Add User +// const now = @intCast(u64, std.time.milliTimestamp()); +// // Create User +// body.created_at = now; +// body.last_activity_at = now; +// body.updated_at = now; - try db.insert(models.User, utils.removeStructFields(body, &[_]u8{0})); +// try db.insert(models.User, utils.removeStructFields(body, &[_]u8{0})); - // Get new User - const query = try models.createSelectOnFieldQuery(models.User, null, "created_at", "="); - const updated_user = try db.selectOne(models.User, allocator, query, .{ .created_at = body.created_at }); - if (updated_user) |user| { - try handler.returnData(user, res); - } else { - handler.returnError("Internal Server Error", 500, res); - } - return; -} +// // Get new User +// const query = try models.createSelectOnFieldQuery(models.User, null, "created_at", "="); +// const updated_user = try db.selectOne(models.User, allocator, query, .{ .created_at = body.created_at }); +// if (updated_user) |user| { +// try handler.returnData(user, res); +// } else { +// handler.returnError("Internal Server Error", 500, res); +// } +// return; +// } -pub fn deleteUser(req: *httpz.Request, res: *httpz.Response) !void { - const db = handler.getDb(); +// pub fn deleteUser(req: *httpz.Request, res: *httpz.Response) !void { +// const db = handler.getDb(); - const user_id = req.param("id"); - if (res.body) |_| { - handler.returnError("Bad Request", 400, res); - return; - } - if (user_id) |id_str| { - const id = std.fmt.parseInt(u32, id_str, 0) catch { - handler.returnError("Bad Request: Invalid Id", 400, res); - return; - }; - db.deleteById(models.User, id) catch |err| { - std.debug.print("Error while deleting user: {}\n", .{err}); - handler.returnError("Internal Server Error", 500, res); - return; - }; - } else { - handler.returnError("Bad Request: Missing ID", 400, res); - } - return; -} +// const user_id = req.param("id"); +// if (res.body) |_| { +// handler.returnError("Bad Request", 400, res); +// return; +// } +// if (user_id) |id_str| { +// const id = std.fmt.parseInt(u32, id_str, 0) catch { +// handler.returnError("Bad Request: Invalid Id", 400, res); +// return; +// }; +// db.deleteById(models.User, id) catch |err| { +// std.debug.print("Error while deleting user: {}\n", .{err}); +// handler.returnError("Internal Server Error", 500, res); +// return; +// }; +// } else { +// handler.returnError("Bad Request: Missing ID", 400, res); +// } +// return; +// } diff --git a/src/utils.zig b/src/utils.zig index ab16f58..4a147d4 100644 --- a/src/utils.zig +++ b/src/utils.zig @@ -95,6 +95,23 @@ pub fn removeStructFields( return result; } +pub fn generateRandomString(allocator: std.mem.Allocator) ![]const u8 { + const chars: []const u8 = "ABCDEFGHIJKJMNOPQRSTUVWXYZ1234567890"; + var xoshiro = std.rand.DefaultPrng.init(@bitCast(u64, std.time.milliTimestamp())); + const rng = xoshiro.random(); + + var code: []u8 = try allocator.alloc(u8, 5); + + for (0..code.len) |i| { + const char_index = rng.uintLessThan(u8, chars.len + 1) % chars.len; + code[i] = chars[char_index]; + } + + std.log.info("Generated family code: {s}", .{code}); + + return code; +} + test { // const vote = .{ .id = 0, .createdAt = "DATE" }; // const data = structConcatFields(vote, .{ .id2 = vote.id }); @@ -103,4 +120,10 @@ test { const user = .{ .id = 0, .createdAt = 2, .other = 3, .key = 4 }; const date = removeStructFields(user, &[_]u8{4}); std.debug.print("\n{any}\n", .{date}); + + var gpa = std.testing.allocator_instance; + var allocator = gpa.allocator(); + + const code = try generateRandomString(allocator); + std.debug.print("\nGot {s}\n", .{code}); }