| package mirko |
| |
// Migration support via github.com/golang-migrate/migrate for go_embed data in Bazel.
| // For example usage, see go/mirko/tests/sql. |
| |
| import ( |
| "bytes" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "os" |
| "strconv" |
| "strings" |
| |
| "github.com/golang-migrate/migrate/v4/source" |
| ) |
| |
| func NewMigrationsFromBazel(data map[string][]byte) (source.Driver, error) { |
| migrations := make(map[uint]*migration) |
| |
| for k, v := range data { |
| parts := strings.Split(k, ".") |
| errInvalid := fmt.Errorf("invalid migration filename: %q", k) |
| |
| if len(parts) != 3 { |
| return nil, errInvalid |
| } |
| if parts[2] != "sql" { |
| return nil, errInvalid |
| } |
| if parts[1] != "up" && parts[1] != "down" { |
| return nil, errInvalid |
| } |
| direction := parts[1] |
| |
| nameParts := strings.SplitN(parts[0], "_", 2) |
| if len(nameParts) != 2 { |
| return nil, errInvalid |
| } |
| |
| name := nameParts[1] |
| |
| version32, err := strconv.ParseUint(nameParts[0], 10, 32) |
| if err != nil { |
| return nil, errInvalid |
| } |
| version := uint(version32) |
| |
| m, ok := migrations[version] |
| if !ok { |
| migrations[version] = &migration{ |
| version: version, |
| name: name, |
| } |
| m = migrations[version] |
| } else { |
| if m.name != name { |
| if err != nil { |
| return nil, fmt.Errorf("migration version %d exists under diffrent names (%q vs %q)", version, name, m.name) |
| } |
| } |
| } |
| |
| if direction == "up" { |
| m.up = v |
| } else { |
| m.down = v |
| } |
| } |
| |
| var first uint |
| for version, migration := range migrations { |
| if migration.up == nil { |
| return nil, fmt.Errorf("migration version %d has no up file", version) |
| } |
| if migration.down == nil { |
| return nil, fmt.Errorf("migration version %d has no down file", version) |
| } |
| if first == 0 { |
| first = version |
| } |
| if version < first { |
| first = version |
| } |
| } |
| |
| if first == 0 { |
| return nil, fmt.Errorf("no migrations, or lowest migration version is 0") |
| } |
| |
| return &migrationSource{ |
| migrations: migrations, |
| first: first, |
| }, nil |
| } |
| |
type (
	// migrationSource is the source.Driver produced by NewMigrationsFromBazel.
	// It serves migrations held entirely in memory.
	migrationSource struct {
		// migrations maps each version to its up/down SQL.
		migrations map[uint]*migration
		// first is the lowest version present in migrations.
		first uint
	}

	// migration is a single version's name plus its up and down SQL bodies.
	migration struct {
		version uint
		name    string
		up      []byte
		down    []byte
	}
)
| |
| func (s *migrationSource) Open(url string) (source.Driver, error) { |
| if url != "" { |
| return nil, fmt.Errorf("bazel migration source is not configure via an URL") |
| } |
| return s, nil |
| } |
| |
| func (s *migrationSource) Close() error { |
| return nil |
| } |
| |
| func (s *migrationSource) First() (uint, error) { |
| return s.first, nil |
| } |
| |
| func (s *migrationSource) Prev(version uint) (uint, error) { |
| var prev uint |
| for ver, _ := range s.migrations { |
| if ver > prev && ver < version { |
| prev = ver |
| } |
| } |
| if prev == 0 { |
| return 0, os.ErrNotExist |
| } |
| return prev, nil |
| } |
| |
| func (s *migrationSource) Next(version uint) (uint, error) { |
| var next uint |
| for ver, _ := range s.migrations { |
| if ver <= version { |
| continue |
| } |
| if next == 0 { |
| next = ver |
| } |
| if ver < next { |
| next = ver |
| } |
| } |
| if next <= version { |
| return 0, os.ErrNotExist |
| } |
| return next, nil |
| } |
| |
| func (s *migrationSource) ReadUp(version uint) (io.ReadCloser, string, error) { |
| m, ok := s.migrations[version] |
| if !ok { |
| return nil, "", os.ErrNotExist |
| } |
| |
| return ioutil.NopCloser(bytes.NewReader(m.up)), m.name, nil |
| } |
| |
| func (s *migrationSource) ReadDown(version uint) (io.ReadCloser, string, error) { |
| m, ok := s.migrations[version] |
| if !ok { |
| return nil, "", os.ErrNotExist |
| } |
| |
| return ioutil.NopCloser(bytes.NewReader(m.down)), m.name, nil |
| } |