Skip to content

Commit

Permalink
added rbs type support in generated code (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
SockworkOrange authored Jul 1, 2024
1 parent bfb32fd commit 538318f
Show file tree
Hide file tree
Showing 16 changed files with 654 additions and 112 deletions.
79 changes: 51 additions & 28 deletions CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public Task<GenerateResponse> Generate(GenerateRequest generateRequest)
DbDriver = InstantiateDriver();
var fileQueries = GetFileQueries();
var files = fileQueries
.Select(fq => GenerateFile(fq.Value, fq.Key))
.SelectMany(fq => GenerateFiles(fq.Value, fq.Key))
.AppendIfNotNull(GenerateGemfile());
return Task.FromResult(new GenerateResponse { Files = { files } });

Expand All @@ -67,20 +67,38 @@ string QueryFilenameToClassName(string filenameWithExtension)
}
}

private File GenerateFile(IList<Query> queries, string className)
private IEnumerable<File> GenerateFiles(IList<Query> queries, string className)
{
var (requiredGems, moduleDeclaration) = GenerateModule(queries, className);
var contents = $"""
{AutoGeneratedComment}
{requiredGems.Select(r => r.Build()).JoinByNewLine()}
{moduleDeclaration.Build()}
""";
return new File
IEnumerable<File> files = new List<File>
{
Name = $"{className.SnakeCase()}.rb",
Contents = ByteString.CopyFromUtf8(contents)
new()
{
Name = $"{className.SnakeCase()}.rb",
Contents = ByteString.CopyFromUtf8(
$"""
{AutoGeneratedComment}
{requiredGems.Select(r => r.Build()).JoinByNewLine()}
{moduleDeclaration.BuildCode()}
"""
)
}
};
if (!Options.GenerateTypes)
return files;

files = files.Append(new File
{
Name = $"{className.SnakeCase()}.rbs",
Contents = ByteString.CopyFromUtf8(
$"""
{AutoGeneratedComment}
{moduleDeclaration.BuildType()}
"""
)
});
return files;
}

private File? GenerateGemfile()
Expand All @@ -91,15 +109,18 @@ private File GenerateFile(IList<Query> queries, string className)
return new File
{
Name = "Gemfile",
Contents = ByteString.CopyFromUtf8($"""
source 'https://rubygems.org'
Contents = ByteString.CopyFromUtf8(
$"""
source 'https://rubygems.org'
{requireGems}
""")
{requireGems}
"""
)
};
}

private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries, string className)
private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries,
string className)
{
var requiredGems = DbDriver.GetRequiredGems();
var initMethod = DbDriver.GetInitMethod();
Expand Down Expand Up @@ -130,37 +151,39 @@ ClassDeclaration GetClassDeclaration()
}
}

private static SimpleStatement GenerateDataclass(string name, ClassMember classMember, IEnumerable<Column> columns,
private IComposableRbsType GenerateDataclass(string funcName, ClassMember classMember, IList<Column> columns,
Options options)
{
var dataclassName = $"{name.FirstCharToUpper()}{classMember.Name()}";
var dataColumns = columns.Select(c => $":{c.Name.ToLower()}").ToList();
var dataColumnsStr = dataColumns.JoinByCommaAndFormat();
return new SimpleStatement(dataclassName,
new SimpleExpression(options.RubyVersion.ImmutableDataSupported()
? $"Data.define({dataColumnsStr})"
: $"Struct.new({dataColumnsStr})"));
var dataclassName = $"{funcName.FirstCharToUpper()}{classMember.Name()}";
var nameToType = columns.ToDictionary(
kv => kv.Name,
kv => DbDriver.GetColumnType(kv)
);
return options.RubyVersion.ImmutableDataSupported()
? new DataDefine(dataclassName, nameToType)
: new NewStruct(dataclassName, nameToType);
}

private SimpleStatement? GetQueryColumnsDataclass(Query query)
private IComposableRbsType? GetQueryColumnsDataclass(Query query)
{
return query.Columns.Count <= 0
? null
: GenerateDataclass(query.Name, ClassMember.Row, query.Columns, Options);
}

private SimpleStatement? GetQueryParamsDataclass(Query query)

private IComposableRbsType? GetQueryParamsDataclass(Query query)
{
if (query.Params.Count <= 0)
return null;
var columns = query.Params.Select(p => p.Column);
var columns = query.Params.Select(p => p.Column).ToList();
return GenerateDataclass(query.Name, ClassMember.Args, columns, Options);
}

private MethodDeclaration GetMethodDeclaration(Query query)
{
var queryTextConstant = GetInterfaceName(ClassMember.Sql);
var argInterface = GetInterfaceName(ClassMember.Args).SnakeCase();
var argInterface = GetInterfaceName(ClassMember.Args);
var returnInterface = GetInterfaceName(ClassMember.Row);
var funcName = query.Name.SnakeCase();

Expand Down
16 changes: 15 additions & 1 deletion Drivers/DbDriver.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Plugin;
using RubyCodegen;
using System;
using System.Collections.Generic;

namespace SqlcGenRuby.Drivers;
Expand All @@ -15,7 +16,20 @@ protected static IEnumerable<RequireGem> GetCommonGems()

public abstract MethodDeclaration GetInitMethod();

public abstract SimpleStatement QueryTextConstantDeclare(Query query);
protected abstract List<(string, HashSet<string>)> GetColumnMapping();

public string GetColumnType(Column column)
{
var columnType = column.Type.Name.ToLower();
foreach (var (csharpType, dbTypes) in GetColumnMapping())
{
if (dbTypes.Contains(columnType))
return csharpType;
}
throw new NotSupportedException($"Unsupported column type: {column.Type.Name}");
}

public abstract PropertyDeclaration QueryTextConstantDeclare(Query query);

public abstract IComposable PrepareStmt(string funcName, string queryTextConstant);

Expand Down
24 changes: 19 additions & 5 deletions Drivers/MethodGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ public MethodDeclaration OneDeclare(string funcName, string queryTextConstant, s
]
).ToList();

return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
return new MethodDeclaration(
funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
returnInterface,
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList())
Expand Down Expand Up @@ -58,7 +62,11 @@ public MethodDeclaration ManyDeclare(string funcName, string queryTextConstant,
]
);

return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
return new MethodDeclaration(
funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
returnInterface,
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList())
Expand All @@ -76,7 +84,10 @@ public MethodDeclaration ExecDeclare(string funcName, string queryTextConstant,
.Append(dbDriver.ExecuteStmt(funcName, queryParams))
.ToList();

return new MethodDeclaration(funcName, GetMethodArgs(argInterface, parameters),
return new MethodDeclaration(funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
null,
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(), withResourceBody.ToList()
Expand All @@ -100,7 +111,10 @@ public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextCons
);

return new MethodDeclaration(
funcName, GetMethodArgs(argInterface, parameters),
funcName,
argInterface,
GetMethodArgs(argInterface, parameters),
"Integer",
new List<IComposable>
{
new WithResource(Variable.Pool.AsProperty(), Variable.Client.AsVar(),
Expand All @@ -111,7 +125,7 @@ public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextCons

private static SimpleStatement? GetQueryParams(string argInterface, IList<Parameter> parameters)
{
var queryParams = parameters.Select(p => $"{argInterface}.{p.Column.Name}").ToList();
var queryParams = parameters.Select(p => $"{argInterface.SnakeCase()}.{p.Column.Name}").ToList();
return queryParams.Count == 0
? null
: new SimpleStatement(Variable.QueryParams.AsVar(),
Expand Down
58 changes: 53 additions & 5 deletions Drivers/Mysql2Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,65 @@ public override IEnumerable<RequireGem> GetRequiredGems()

public override MethodDeclaration GetInitMethod()
{
return new MethodDeclaration("initialize", "connection_pool_params, mysql2_params",
var connectionPoolInit = new NewObject("ConnectionPool",
new[] { new SimpleExpression("**connection_pool_params") },
new List<IComposable> { new SimpleExpression("Mysql2::Client.new(**mysql2_params)") });
return new MethodDeclaration(
"initialize",
"Hash[String, String], Hash[String, String]",
"connection_pool_params, mysql2_params",
null,
[
new SimpleStatement(Variable.Pool.AsProperty(), new SimpleExpression(
"ConnectionPool::new(**connection_pool_params) { Mysql2::Client.new(**mysql2_params) }"))
new PropertyDeclaration(Variable.Pool.AsProperty(), "untyped", connectionPoolInit)
]
);
}

public override SimpleStatement QueryTextConstantDeclare(Query query)
protected override List<(string, HashSet<string>)> GetColumnMapping()
{
return new SimpleStatement($"{query.Name}{ClassMember.Sql}", new SimpleExpression($"%q({query.Text})"));
return
[
("Array[Integer]", [
"binary",
"bit",
"blob",
"longblob",
"mediumblob",
"tinyblob",
"varbinary"
]),
("String", [
"char",
"date",
"datetime",
"decimal",
"longtext",
"mediumtext",
"text",
"time",
"timestamp",
"tinytext",
"varchar",
"json"
]),
("Integer", [
"bigint",
"int",
"mediumint",
"smallint",
"tinyint",
"year"
]),
("Float", ["double", "float"]),
];
}

public override PropertyDeclaration QueryTextConstantDeclare(Query query)
{
return new PropertyDeclaration(
$"{query.Name}{ClassMember.Sql}",
"String",
new SimpleExpression($"%q({query.Text})"));
}

public override IComposable PrepareStmt(string _, string queryTextConstant)
Expand Down
73 changes: 63 additions & 10 deletions Drivers/PgDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ public override IEnumerable<RequireGem> GetRequiredGems()

public override MethodDeclaration GetInitMethod()
{
return new MethodDeclaration("initialize", "connection_pool_params, pg_params",
var connectionPoolInit = new NewObject("ConnectionPool",
new[] { new SimpleExpression("**connection_pool_params") },
PgClientCreate());
return new MethodDeclaration(
"initialize",
"Hash[String, String], Hash[String, String]",
"connection_pool_params, pg_params",
null,
[
new SimpleStatement(
Variable.Pool.AsProperty(),
new NewObject("ConnectionPool",
new[] { new SimpleExpression("**connection_pool_params") },
PgClientCreate())),
new SimpleStatement(Variable.PreparedStatements.AsProperty(), new SimpleExpression("Set[]"))
new PropertyDeclaration(Variable.Pool.AsProperty(), "untyped", connectionPoolInit),
new PropertyDeclaration(Variable.PreparedStatements.AsProperty(), "Set[String]", new SimpleExpression("Set[]"))
]
);

Expand All @@ -50,11 +53,61 @@ IList<IComposable> PgClientCreate()
}
}

public override SimpleStatement QueryTextConstantDeclare(Query query)
protected override List<(string, HashSet<string>)> GetColumnMapping()
{
return
[
("bool", [
"bool",
"boolean"
]),
("Array[Integer]", [
"binary",
"bit",
"bytea",
"blob",
"longblob",
"mediumblob",
"tinyblob",
"varbinary"
]),
("String", [
"char",
"date",
"datetime",
"longtext",
"mediumtext",
"text",
"bpchar",
"time",
"timestamp",
"tinytext",
"varchar",
"json"
]),
("Integer", [
"int2",
"int4",
"int8",
"serial",
"bigserial"
]),
("Float", [
"numeric",
"float4",
"float8",
"decimal"
])
];
}

public override PropertyDeclaration QueryTextConstantDeclare(Query query)
{
var counter = 1;
var transformedQueryText = BindRegexToReplace().Replace(query.Text, m => $"${counter++}");
return new SimpleStatement($"{query.Name}{ClassMember.Sql}",
var transformedQueryText = BindRegexToReplace().Replace(query.Text, _ => $"${counter++}");
return new PropertyDeclaration(
$"{query.Name}{ClassMember.Sql}",
"String",
new SimpleExpression($"%q({transformedQueryText})"));
}

Expand Down
2 changes: 1 addition & 1 deletion Extensions/ListExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static string JoinByNewLine(this IEnumerable<string> me, int cnt = 1)
public static string JoinByCommaAndFormat(this IList<string> me)
{
return me.Count < MaxElementsPerLine
? string.Join(", ", me).Indent()
? string.Join(", ", me)
: $"\n{string.Join(",\n", me).Indent()}\n";
}
}
Loading

0 comments on commit 538318f

Please # to comment.