blob: 61c26f2c85e63810326b748ca2e018a9f8002e97 [file] [log] [blame]
package mirko
// Migration support via github.com/golang-migrate/migrate for go_embed data in Bazel.
// For example usage, see go/mirko/tests/sql.
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"github.com/golang-migrate/migrate/v4/source"
)
// NewMigrationsFromBazel returns a golang-migrate source.Driver that serves
// migrations from data, a map of filename to file contents such as the one
// produced by a Bazel go_embed rule.
//
// Every key must be named "<version>_<name>.<up|down>.sql", where <version>
// is a decimal integer that fits in 32 bits. Each version needs both an up
// and a down file, and both must carry the same <name>.
//
// An error is returned if a filename is malformed, a version lacks its up or
// down file, one version appears under two different names, data is empty,
// or any migration uses version 0. Version 0 is rejected because Prev and
// Next use 0 internally to mean "no version".
func NewMigrationsFromBazel(data map[string][]byte) (source.Driver, error) {
	migrations := make(map[uint]*migration)
	for filename, contents := range data {
		version, name, direction, err := parseMigrationFilename(filename)
		if err != nil {
			return nil, err
		}
		m, ok := migrations[version]
		if !ok {
			m = &migration{version: version, name: name}
			migrations[version] = m
		} else if m.name != name {
			// Both directions of a version must agree on the name; otherwise
			// it is ambiguous which name the source should report.
			return nil, fmt.Errorf("migration version %d exists under different names (%q vs %q)", version, name, m.name)
		}
		if direction == "up" {
			m.up = contents
		} else {
			m.down = contents
		}
	}

	// Check for version 0 explicitly rather than relying on the minimum scan
	// below, whose outcome for a 0 entry would depend on map iteration order.
	if _, hasZero := migrations[0]; hasZero || len(migrations) == 0 {
		return nil, fmt.Errorf("no migrations, or lowest migration version is 0")
	}

	var first uint
	for version, m := range migrations {
		// NOTE(review): a nil slice is treated as "file missing"; confirm how
		// go_embed represents an empty .sql file.
		if m.up == nil {
			return nil, fmt.Errorf("migration version %d has no up file", version)
		}
		if m.down == nil {
			return nil, fmt.Errorf("migration version %d has no down file", version)
		}
		if first == 0 || version < first {
			first = version
		}
	}
	return &migrationSource{
		migrations: migrations,
		first:      first,
	}, nil
}

// parseMigrationFilename splits a "<version>_<name>.<up|down>.sql" filename
// into its numeric version, its name, and its direction ("up" or "down").
// Any deviation from that shape yields an "invalid migration filename" error.
func parseMigrationFilename(filename string) (uint, string, string, error) {
	errInvalid := fmt.Errorf("invalid migration filename: %q", filename)
	parts := strings.Split(filename, ".")
	if len(parts) != 3 || parts[2] != "sql" {
		return 0, "", "", errInvalid
	}
	direction := parts[1]
	if direction != "up" && direction != "down" {
		return 0, "", "", errInvalid
	}
	nameParts := strings.SplitN(parts[0], "_", 2)
	if len(nameParts) != 2 {
		return 0, "", "", errInvalid
	}
	version, err := strconv.ParseUint(nameParts[0], 10, 32)
	if err != nil {
		return 0, "", "", errInvalid
	}
	return uint(version), nameParts[1], direction, nil
}
// Types backing the in-memory migration source built by
// NewMigrationsFromBazel.
type (
	// migrationSource serves migrations held in memory, keyed by version;
	// first caches the lowest version present in migrations.
	migrationSource struct {
		migrations map[uint]*migration
		first      uint
	}

	// migration holds both directions of one numbered migration.
	migration struct {
		version uint   // numeric prefix parsed from the filename
		name    string // text between "<version>_" and ".<direction>.sql"
		up      []byte // contents of the .up.sql file
		down    []byte // contents of the .down.sql file
	}
)
// Open satisfies source.Driver. The source is configured entirely in memory
// by NewMigrationsFromBazel, so only an empty URL is accepted, and the
// receiver itself is handed back.
func (s *migrationSource) Open(url string) (source.Driver, error) {
	if url == "" {
		return s, nil
	}
	return nil, fmt.Errorf("bazel migration source is not configure via an URL")
}
// Close is a no-op: the source only holds in-memory data, so there is
// nothing to release.
func (s *migrationSource) Close() error { return nil }
// First returns the lowest migration version known to the source.
//
// As required by the source.Driver contract, it reports os.ErrNotExist when
// no version is available. A migrationSource built by NewMigrationsFromBazel
// always has a non-zero first version; the check guards the zero value and
// matches Prev/Next, which also treat 0 as "no version".
func (s *migrationSource) First() (uint, error) {
	if s.first == 0 {
		return 0, os.ErrNotExist
	}
	return s.first, nil
}
// Prev returns the largest known version strictly below version, or
// os.ErrNotExist when there is none. Version 0 doubles as the "not found"
// marker, so it is never returned.
func (s *migrationSource) Prev(version uint) (uint, error) {
	best := uint(0)
	for candidate := range s.migrations {
		if candidate >= version || candidate <= best {
			continue
		}
		best = candidate
	}
	if best != 0 {
		return best, nil
	}
	return 0, os.ErrNotExist
}
// Next returns the smallest known version strictly above version, or
// os.ErrNotExist when there is none.
func (s *migrationSource) Next(version uint) (uint, error) {
	var best uint // 0 means nothing above version has been seen yet
	for candidate := range s.migrations {
		if candidate > version && (best == 0 || candidate < best) {
			best = candidate
		}
	}
	// best is only ever assigned values > version (hence >= 1), so 0 here
	// means no candidate qualified.
	if best == 0 {
		return 0, os.ErrNotExist
	}
	return best, nil
}
// ReadUp returns a reader over the up SQL for version along with the
// migration's name, or os.ErrNotExist if version is unknown.
func (s *migrationSource) ReadUp(version uint) (io.ReadCloser, string, error) {
	if m, found := s.migrations[version]; found {
		body := bytes.NewReader(m.up)
		return ioutil.NopCloser(body), m.name, nil
	}
	return nil, "", os.ErrNotExist
}
// ReadDown returns a reader over the down SQL for version along with the
// migration's name, or os.ErrNotExist if version is unknown.
func (s *migrationSource) ReadDown(version uint) (io.ReadCloser, string, error) {
	if m, found := s.migrations[version]; found {
		body := bytes.NewReader(m.down)
		return ioutil.NopCloser(body), m.name, nil
	}
	return nil, "", os.ErrNotExist
}