mirror of https://github.com/writeas/writefreely
A focused writing and publishing space.
https://write.with.parts
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
266 lines
5.3 KiB
266 lines
5.3 KiB
/*
 * Copyright © 2019-2020 Musing Studio LLC.
 *
 * This file is part of WriteFreely.
 *
 * WriteFreely is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License, included
 * in the LICENSE file in this source code package.
 */
|
|
|
|
package db
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// ColumnType identifies an abstract column type that Format renders into the
// SQL type name of a specific dialect.
type ColumnType int
|
|
|
|
// OptionalInt is an int that may be absent.
type OptionalInt struct {
	// Set reports whether Value carries a meaningful number.
	Set bool
	// Value is the number itself; readers check Set before using it.
	Value int
}
|
|
|
|
// OptionalString is a string that may be absent. It lets an explicitly set
// empty string be told apart from no value at all.
type OptionalString struct {
	// Set reports whether Value was explicitly provided.
	Set bool
	// Value is the string itself; readers check Set before using it.
	Value string
}
|
|
|
|
// SQLBuilder is implemented by values that can render themselves as a SQL
// statement.
type SQLBuilder interface {
	// ToSQL returns the statement text, or an error when it cannot be built.
	ToSQL() (string, error)
}
|
|
|
|
type Column struct {
|
|
Dialect DialectType
|
|
Name string
|
|
Nullable bool
|
|
Default OptionalString
|
|
Type ColumnType
|
|
Size OptionalInt
|
|
PrimaryKey bool
|
|
}
|
|
|
|
type CreateTableSqlBuilder struct {
|
|
Dialect DialectType
|
|
Name string
|
|
IfNotExists bool
|
|
ColumnOrder []string
|
|
Columns map[string]*Column
|
|
Constraints []string
|
|
}
|
|
|
|
const (
|
|
ColumnTypeBool ColumnType = iota
|
|
ColumnTypeSmallInt ColumnType = iota
|
|
ColumnTypeInteger ColumnType = iota
|
|
ColumnTypeChar ColumnType = iota
|
|
ColumnTypeVarChar ColumnType = iota
|
|
ColumnTypeText ColumnType = iota
|
|
ColumnTypeDateTime ColumnType = iota
|
|
)
|
|
|
|
// Compile-time check that *CreateTableSqlBuilder satisfies SQLBuilder.
var _ SQLBuilder = (*CreateTableSqlBuilder)(nil)
|
|
|
|
var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
|
|
var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
|
|
|
|
func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
|
|
if dialect != DialectMySQL && dialect != DialectSQLite {
|
|
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
|
|
}
|
|
switch d {
|
|
case ColumnTypeSmallInt:
|
|
{
|
|
if dialect == DialectSQLite {
|
|
return "INTEGER", nil
|
|
}
|
|
mod := ""
|
|
if size.Set {
|
|
mod = fmt.Sprintf("(%d)", size.Value)
|
|
}
|
|
return "SMALLINT" + mod, nil
|
|
}
|
|
case ColumnTypeInteger:
|
|
{
|
|
if dialect == DialectSQLite {
|
|
return "INTEGER", nil
|
|
}
|
|
mod := ""
|
|
if size.Set {
|
|
mod = fmt.Sprintf("(%d)", size.Value)
|
|
}
|
|
return "INT" + mod, nil
|
|
}
|
|
case ColumnTypeChar:
|
|
{
|
|
if dialect == DialectSQLite {
|
|
return "TEXT", nil
|
|
}
|
|
mod := ""
|
|
if size.Set {
|
|
mod = fmt.Sprintf("(%d)", size.Value)
|
|
}
|
|
return "CHAR" + mod, nil
|
|
}
|
|
case ColumnTypeVarChar:
|
|
{
|
|
if dialect == DialectSQLite {
|
|
return "TEXT", nil
|
|
}
|
|
mod := ""
|
|
if size.Set {
|
|
mod = fmt.Sprintf("(%d)", size.Value)
|
|
}
|
|
return "VARCHAR" + mod, nil
|
|
}
|
|
case ColumnTypeBool:
|
|
{
|
|
if dialect == DialectSQLite {
|
|
return "INTEGER", nil
|
|
}
|
|
return "TINYINT(1)", nil
|
|
}
|
|
case ColumnTypeDateTime:
|
|
return "DATETIME", nil
|
|
case ColumnTypeText:
|
|
return "TEXT", nil
|
|
}
|
|
return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
|
|
}
|
|
|
|
func (c *Column) SetName(name string) *Column {
|
|
c.Name = name
|
|
return c
|
|
}
|
|
|
|
func (c *Column) SetNullable(nullable bool) *Column {
|
|
c.Nullable = nullable
|
|
return c
|
|
}
|
|
|
|
func (c *Column) SetPrimaryKey(pk bool) *Column {
|
|
c.PrimaryKey = pk
|
|
return c
|
|
}
|
|
|
|
func (c *Column) SetDefault(value string) *Column {
|
|
c.Default = OptionalString{Set: true, Value: value}
|
|
return c
|
|
}
|
|
|
|
func (c *Column) SetDefaultCurrentTimestamp() *Column {
|
|
def := "NOW()"
|
|
if c.Dialect == DialectSQLite {
|
|
def = "CURRENT_TIMESTAMP"
|
|
}
|
|
c.Default = OptionalString{Set: true, Value: def}
|
|
return c
|
|
}
|
|
|
|
func (c *Column) SetType(t ColumnType) *Column {
|
|
c.Type = t
|
|
return c
|
|
}
|
|
|
|
func (c *Column) SetSize(size int) *Column {
|
|
c.Size = OptionalInt{Set: true, Value: size}
|
|
return c
|
|
}
|
|
|
|
func (c *Column) String() (string, error) {
|
|
var str strings.Builder
|
|
|
|
str.WriteString(c.Name)
|
|
|
|
str.WriteString(" ")
|
|
typeStr, err := c.Type.Format(c.Dialect, c.Size)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
str.WriteString(typeStr)
|
|
|
|
if !c.Nullable {
|
|
str.WriteString(" NOT NULL")
|
|
}
|
|
|
|
if c.Default.Set {
|
|
str.WriteString(" DEFAULT ")
|
|
val := c.Default.Value
|
|
if val == "" {
|
|
val = "''"
|
|
}
|
|
str.WriteString(val)
|
|
}
|
|
|
|
if c.PrimaryKey {
|
|
str.WriteString(" PRIMARY KEY")
|
|
}
|
|
|
|
return str.String(), nil
|
|
}
|
|
|
|
func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
|
|
if b.Columns == nil {
|
|
b.Columns = make(map[string]*Column)
|
|
}
|
|
b.Columns[column.Name] = column
|
|
b.ColumnOrder = append(b.ColumnOrder, column.Name)
|
|
return b
|
|
}
|
|
|
|
func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
|
|
for _, column := range columns {
|
|
if _, ok := b.Columns[column]; !ok {
|
|
// This fails silently.
|
|
return b
|
|
}
|
|
}
|
|
b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
|
|
return b
|
|
}
|
|
|
|
func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
|
|
b.IfNotExists = ine
|
|
return b
|
|
}
|
|
|
|
func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
|
|
var str strings.Builder
|
|
|
|
str.WriteString("CREATE TABLE ")
|
|
if b.IfNotExists {
|
|
str.WriteString("IF NOT EXISTS ")
|
|
}
|
|
str.WriteString(b.Name)
|
|
|
|
var things []string
|
|
for _, columnName := range b.ColumnOrder {
|
|
column, ok := b.Columns[columnName]
|
|
if !ok {
|
|
return "", fmt.Errorf("column not found: %s", columnName)
|
|
}
|
|
columnStr, err := column.String()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
things = append(things, columnStr)
|
|
}
|
|
for _, constraint := range b.Constraints {
|
|
things = append(things, constraint)
|
|
}
|
|
|
|
if thingLen := len(things); thingLen > 0 {
|
|
str.WriteString(" ( ")
|
|
for i, thing := range things {
|
|
str.WriteString(thing)
|
|
if i < thingLen-1 {
|
|
str.WriteString(", ")
|
|
}
|
|
}
|
|
str.WriteString(" )")
|
|
}
|
|
|
|
return str.String(), nil
|
|
}
|
|
|