Skip to content

Commit

Permalink
standardize list construction across project (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
SockworkOrange authored Jul 20, 2024
1 parent 9ab0c3a commit 49bc5d1
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 76 deletions.
69 changes: 33 additions & 36 deletions CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ public Task<GenerateResponse> Generate(GenerateRequest generateRequest)
{
Options = new Options(generateRequest);
DbDriver = InstantiateDriver();
var fileQueries = GetFileQueries();
var files = fileQueries
var generatedFiles = GetFileQueries()
.SelectMany(fq => GenerateFiles(fq.Value, fq.Key))
.AppendIfNotNull(GenerateGemfile());
return Task.FromResult(new GenerateResponse { Files = { files } });
.AppendIf(GenerateGemfile(), Options.GenerateGemfile);
return Task.FromResult(new GenerateResponse { Files = { generatedFiles } });

Dictionary<string, Query[]> GetFileQueries()
{
Expand All @@ -71,25 +70,31 @@ string QueryFilenameToClassName(string filenameWithExtension)
private IEnumerable<File> GenerateFiles(IList<Query> queries, string className)
{
var (requiredGems, moduleDeclaration) = GenerateModule(queries, className);
IEnumerable<File> files = new List<File>
return new List<File>()
.Append(GenerateCodeFile(className, requiredGems, moduleDeclaration))
.AppendIf(GenerateTypedefFile(className, moduleDeclaration), Options.GenerateTypes);
}

private static File GenerateCodeFile(string className, IEnumerable<RequireGem> requiredGems,
ModuleDeclaration moduleDeclaration)
{
return new File
{
new()
{
Name = $"{className.SnakeCase()}.rb",
Contents = ByteString.CopyFromUtf8(
$"""
{AutoGeneratedComment}
{requiredGems.Select(r => r.Build()).JoinByNewLine()}
{moduleDeclaration.BuildCode()}
"""
)
}
Name = $"{className.SnakeCase()}.rb",
Contents = ByteString.CopyFromUtf8(
$"""
{AutoGeneratedComment}
{requiredGems.Select(r => r.Build()).JoinByNewLine()}
{moduleDeclaration.BuildCode()}
"""
)
};
if (!Options.GenerateTypes)
return files;
}

files = files.Append(new File
private static File GenerateTypedefFile(string className, ModuleDeclaration moduleDeclaration)
{
return new File
{
Name = $"{className.SnakeCase()}.rbs",
Contents = ByteString.CopyFromUtf8(
Expand All @@ -98,29 +103,25 @@ private IEnumerable<File> GenerateFiles(IList<Query> queries, string className)
{moduleDeclaration.BuildType()}
"""
)
});
return files;
};
}

private File? GenerateGemfile()
private File GenerateGemfile()
{
if (!Options.GenerateGemfile)
return null;
var requireGems = DbDriver.GetRequiredGems().Select(gem => $"gem '{gem.Name()}'").JoinByNewLine();
return new File
{
Name = "Gemfile",
Contents = ByteString.CopyFromUtf8(
$"""
source 'https://rubygems.org'
{requireGems}
{DbDriver.GetRequiredGems().Select(gem => $"gem '{gem.Name()}'").JoinByNewLine()}
"""
)
};
}

private (IEnumerable<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries,
private (IList<RequireGem>, ModuleDeclaration) GenerateModule(IList<Query> queries,
string className)
{
var requiredGems = DbDriver.GetRequiredGems();
Expand All @@ -132,14 +133,10 @@ ModuleDeclaration GetModuleDeclaration()
{
return new ModuleDeclaration($"{Options.DriverName.ToString()}Codegen",
queries
.SelectMany(q =>
{
IEnumerable<IComposable> members = new List<IComposable>();
members = members.Append(DbDriver.QueryTextConstantDeclare(q));
members = members.AppendIfNotNull(GetQueryColumnsDataclass(q));
members = members.AppendIfNotNull(GetQueryParamsDataclass(q));
return members;
})
.SelectMany(q => new List<IComposable>()
.Append(DbDriver.QueryTextConstantDeclare(q))
.AppendIfNotNull(GetQueryColumnsDataclass(q))
.AppendIfNotNull(GetQueryParamsDataclass(q)))
.Append(classDeclaration));
}

Expand Down
4 changes: 2 additions & 2 deletions Drivers/DbDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace SqlcGenRuby.Drivers;

public abstract class DbDriver
{
protected static IEnumerable<RequireGem> GetCommonGems()
protected static IList<RequireGem> GetCommonGems()
{
return [new RequireGem("connection_pool")];
}

public abstract IEnumerable<RequireGem> GetRequiredGems();
public abstract IList<RequireGem> GetRequiredGems();

public abstract MethodDeclaration GetInitMethod();

Expand Down
56 changes: 24 additions & 32 deletions Drivers/MethodGen.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Plugin;
using RubyCodegen;
using System;
using System.Collections.Generic;
using System.Linq;

Expand Down Expand Up @@ -59,38 +58,36 @@ public MethodDeclaration ManyDeclare(string funcName, string queryTextConstant,
string returnInterface, IList<Parameter> parameters, IList<Column> columns, bool poolingEnabled = true,
RowDataType rowDataType = RowDataType.Hash)
{
var listAppend = new ListAppend(Variable.Entities.AsVar(),
new NewObject(returnInterface, GetColumnsInitExpressions(columns, rowDataType)));
IEnumerable<IComposable> withResourceBody = new List<IComposable>();
var queryParams = GetQueryParams(argInterface, parameters);
withResourceBody = withResourceBody.AppendIfNotNull(queryParams);
withResourceBody = withResourceBody
.Concat(
[
dbDriver.PrepareStmt(funcName, queryTextConstant),
ExecuteAndAssign(funcName, queryParams),
new SimpleStatement(Variable.Entities.AsVar(), new SimpleExpression("[]")),
new ForeachLoop(
Variable.Result.AsVar(),
Variable.Row.AsVar(),
new List<IComposable> { listAppend }
),
new SimpleExpression($"return {Variable.Entities.AsVar()}")
]
);
var withResourceBody = new List<IComposable>()
.AppendIfNotNull(queryParams)
.Append(dbDriver.PrepareStmt(funcName, queryTextConstant))
.Append(ExecuteAndAssign(funcName, queryParams))
.Append(new SimpleStatement(Variable.Entities.AsVar(), new SimpleExpression("[]")))
.Append(AssignResultInForeach())
.Append(new SimpleExpression($"return {Variable.Entities.AsVar()}"));

var methodArgs = GetMethodArgs(argInterface, parameters);
var methodBody = OptionallyAddPoolUsage(poolingEnabled, withResourceBody);
return new MethodDeclaration(funcName, argInterface, methodArgs, null, methodBody);

ForeachLoop AssignResultInForeach()
{
var listAppend = new ListAppend(Variable.Entities.AsVar(),
new NewObject(returnInterface, GetColumnsInitExpressions(columns, rowDataType)));
return new ForeachLoop(
Variable.Result.AsVar(),
Variable.Row.AsVar(),
new List<IComposable> { listAppend });
}
}

public MethodDeclaration ExecDeclare(string funcName, string queryTextConstant, string argInterface,
IList<Parameter> parameters, bool poolingEnabled = true)
{
IEnumerable<IComposable> withResourceBody = new List<IComposable>();
var queryParams = GetQueryParams(argInterface, parameters);
withResourceBody = withResourceBody.AppendIfNotNull(queryParams);
withResourceBody = withResourceBody
var withResourceBody = new List<IComposable>()
.AppendIfNotNull(queryParams)
.Append(dbDriver.PrepareStmt(funcName, queryTextConstant))
.Append(dbDriver.ExecuteStmt(funcName, queryParams))
.ToList();
Expand All @@ -103,17 +100,12 @@ public MethodDeclaration ExecDeclare(string funcName, string queryTextConstant,
public MethodDeclaration ExecLastIdDeclare(string funcName, string queryTextConstant, string argInterface,
IList<Parameter> parameters)
{
IEnumerable<IComposable> withResourceBody = new List<IComposable>();
var queryParams = GetQueryParams(argInterface, parameters);
withResourceBody = withResourceBody.AppendIfNotNull(queryParams);
withResourceBody = withResourceBody
.Concat(
[
dbDriver.PrepareStmt(funcName, queryTextConstant),
dbDriver.ExecuteStmt(funcName, queryParams),
new SimpleExpression($"return {Variable.Client.AsVar()}.last_id")
]
);
var withResourceBody = new List<IComposable>()
.AppendIfNotNull(queryParams)
.Append(dbDriver.PrepareStmt(funcName, queryTextConstant))
.Append(dbDriver.ExecuteStmt(funcName, queryParams))
.Append(new SimpleExpression($"return {Variable.Client.AsVar()}.last_id"));

return new MethodDeclaration(
funcName,
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Mysql2Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ public Mysql2Driver()
MethodGen = new MethodGen(this);
}

public override IEnumerable<RequireGem> GetRequiredGems()
public override IList<RequireGem> GetRequiredGems()
{
return GetCommonGems().Append(new RequireGem("mysql2"));
return GetCommonGems().Append(new RequireGem("mysql2")).ToList();
}

public override MethodDeclaration GetInitMethod()
Expand Down
5 changes: 3 additions & 2 deletions Drivers/PgDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ public PgDriver()
MethodGen = new MethodGen(this);
}

public override IEnumerable<RequireGem> GetRequiredGems()
public override IList<RequireGem> GetRequiredGems()
{
return GetCommonGems()
.Append(new RequireGem("pg"))
.Append(new RequireGem("set"));
.Append(new RequireGem("set"))
.ToList();
}

public override MethodDeclaration GetInitMethod()
Expand Down
4 changes: 2 additions & 2 deletions Drivers/Sqlite3Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ public Sqlite3Driver()
MethodGen = new MethodGen(this);
}

public override IEnumerable<RequireGem> GetRequiredGems()
public override IList<RequireGem> GetRequiredGems()
{
return GetCommonGems().Append(new RequireGem("sqlite3"));
return GetCommonGems().Append(new RequireGem("sqlite3")).ToList();
}

public override MethodDeclaration GetInitMethod()
Expand Down
5 changes: 5 additions & 0 deletions Extensions/ListExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ public static class ListExtensions
{
private const int MaxElementsPerLine = 5;

public static IEnumerable<T> AppendIf<T>(this IEnumerable<T> me, T item, bool condition)
{
return condition ? me.Append(item) : me;
}

public static IEnumerable<T> AppendIfNotNull<T>(this IEnumerable<T> me, T? item)
{
return item is not null ? me.Append(item) : me;
Expand Down

0 comments on commit 49bc5d1

Please # to comment.