blob: ffb6c94761716ba2ad567e57031eae9b9c3713a3 [file] [log] [blame]
package kubenat
import (
"bufio"
"bytes"
"context"
"fmt"
"io/ioutil"
"net"
"strconv"
"strings"
"github.com/golang/glog"
)
// translationReq is a request passed to the translationWorker.
type translationReq struct {
// t is the 4-tuple that the worker looks up in its conntrack index.
t *Tuple4
// res is the channel on which the reply to this request is sent: either a
// translationResp, or nil when no usable entry was found.
res chan *translationResp
}
// translationResp is a response from the translationWorker, sent over the res
// channel in a translationReq.
type translationResp struct {
// localIP is parsed from the "src" field of the matched entry's request
// key-value pairs.
localIP net.IP
// localPort is parsed from the "sport" field of the matched entry's request
// key-value pairs.
localPort uint16
}
// reply sends a reply to the given translationReq based on a conntrackEntry.
//
// It sends nil if the entry is nil, or if the entry's request-side "src" or
// "sport" fields are missing or cannot be parsed, so that a receiver never
// gets a response carrying a nil IP address. Otherwise it sends a
// translationResp holding the parsed address and port.
//
// The send on r.res blocks until the channel accepts the value.
func (r *translationReq) reply(e *conntrackEntry) {
	if e == nil {
		r.res <- nil
		return
	}

	// net.ParseIP reports failure by returning nil rather than an error.
	localIP := net.ParseIP(e.request["src"])
	if localIP == nil {
		r.res <- nil
		return
	}

	// A bitSize of 16 makes ParseUint reject values that do not fit a port.
	localPort, err := strconv.ParseUint(e.request["sport"], 10, 16)
	if err != nil {
		r.res <- nil
		return
	}

	r.res <- &translationResp{
		localIP:   localIP,
		localPort: uint16(localPort),
	}
}
// translate performs a translationReq/translationResp exchange under a context
// that can be used to time out the query.
//
// The context bounds both halves of the exchange: handing the request to the
// translation worker over r.translationC, and waiting for its reply. If ctx is
// done before either completes, translate returns nil and ctx.Err(). A nil
// response with a nil error means the worker replied that it found no usable
// entry for t.
func (r *Resolver) translate(ctx context.Context, t *Tuple4) (*translationResp, error) {
	// Buffered so that the worker's reply never blocks, even when we have
	// already given up waiting for it.
	resC := make(chan *translationResp, 1)
	req := &translationReq{
		t:   t,
		res: resC,
	}

	// The worker may be busy or already stopped; do not block on the send
	// past the lifetime of ctx.
	select {
	case r.translationC <- req:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case res := <-resC:
		return res, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// conntrackEntry is an entry parsed from /proc/net/nf_conntrack. The format is
// not well documented, and the best resource I could find is:
// https://stackoverflow.com/questions/16034698/details-of-proc-net-ip-conntrack-and-proc-net-nf-conntrack
type conntrackEntry struct {
// networkProtocol is currently always "ipv4".
networkProtocol string
// transmissionProtocol is currently "tcp" or "udp".
transmissionProtocol string
// invalidateTimeout is the integer parsed from the fifth column of the line.
// NOTE(review): presumably the time left until the entry expires - confirm
// the unit against the kernel documentation.
invalidateTimeout int64
// state is the sixth column of a "tcp" line. It is left empty for "udp"
// entries.
state string
// request key-value pairs. For NAT, these are entries relating to the
// connection as seen as the 'inside' of the NAT, eg. the pod-originated
// connection. The first occurrence of a key on a line is stored here.
request map[string]string
// response key-value pairs. For NAT, these are entries relating to the
// connection as seen by the 'outside' of the NAT, eg. the internet. The
// second occurrence of a key on a line is stored here.
response map[string]string
// tags is the set of bracketed columns found on the line (eg. [ASSURED]),
// stored without the surrounding brackets.
tags map[string]bool
}
// conntrackParseEntry turns a single line of /proc/net/nf_conntrack into a
// conntrackEntry.
//
// A line that describes something this package does not handle (a network
// protocol other than ipv4, or a transport other than tcp or udp) yields a nil
// entry together with a nil error. A malformed line yields an error.
func conntrackParseEntry(line string) (*conntrackEntry, error) {
	cols := strings.Fields(line)
	if len(cols) < 5 {
		// Only possible if the file format changed drastically. Give up on
		// the line right away and leave it to a human to investigate.
		return nil, fmt.Errorf("invalid field count: %v", cols)
	}
	netProto, netProtoNum := cols[0], cols[1]
	transProto, transProtoNum := cols[2], cols[3]
	rawTimeout := cols[4]

	// TODO(q3k): support IPv6 when we get it on prod.
	if netProto != "ipv4" {
		return nil, nil
	}
	if netProtoNum != "2" {
		return nil, fmt.Errorf("ipv4 with proto number %q, wanted 2", netProtoNum)
	}

	// Everything after the fixed columns is a k/v pair or a tag. TCP lines
	// carry one extra fixed column, the connection state.
	var state string
	attrs := cols[5:]
	switch transProto {
	case "udp":
		if transProtoNum != "17" {
			return nil, fmt.Errorf("udp with proto number %q, wanted 17", transProtoNum)
		}
	case "tcp":
		if transProtoNum != "6" {
			return nil, fmt.Errorf("tcp with proto number %q, wanted 6", transProtoNum)
		}
		if len(cols) < 6 {
			return nil, fmt.Errorf("tcp with missing state field")
		}
		state, attrs = cols[5], cols[6:]
	default:
		return nil, nil
	}

	timeout, err := strconv.ParseInt(rawTimeout, 10, 64)
	if err != nil {
		return nil, fmt.Errorf("unparseable timeout %q", rawTimeout)
	}

	parsed := &conntrackEntry{
		networkProtocol:      netProto,
		transmissionProtocol: transProto,
		invalidateTimeout:    timeout,
		state:                state,
		request:              map[string]string{},
		response:             map[string]string{},
		tags:                 map[string]bool{},
	}

	for _, attr := range attrs {
		pieces := strings.Split(attr, "=")
		if len(pieces) > 2 {
			return nil, fmt.Errorf("unparseable column %q", attr)
		}

		if len(pieces) == 1 {
			// No '=' at all: a tag such as [ASSURED]. Anything that is not
			// wrapped in brackets is ignored.
			if !(strings.HasPrefix(attr, "[") && strings.HasSuffix(attr, "]")) {
				continue
			}
			name := attr[1 : len(attr)-1]
			if parsed.tags[name] {
				return nil, fmt.Errorf("repeated tag %q", name)
			}
			parsed.tags[name] = true
			continue
		}

		// Exactly one '=': a k/v pair. The first sighting of a key belongs to
		// the request side, the second to the response side.
		key, val := pieces[0], pieces[1]
		_, inRequest := parsed.request[key]
		_, inResponse := parsed.response[key]
		switch {
		case !inRequest:
			parsed.request[key] = val
		case !inResponse:
			parsed.response[key] = val
		default:
			return nil, fmt.Errorf("field %q encountered more than twice", key)
		}
	}
	return parsed, nil
}
// conntrackParse parses the contents of a /proc/net/nf_conntrack file into
// multiple entries.
//
// Blank lines, and lines for which conntrackParseEntry returns a nil entry,
// are skipped. Every line that fails to parse is logged and counted. A failure
// of the line scanner itself (for example a line longer than bufio.Scanner's
// token limit, after which no further lines can be read) is logged and counted
// the same way instead of being silently dropped.
//
// If there were no failures, or fewer failures than parsed entries, the
// entries are returned. Otherwise nil is returned along with an error wrapping
// the first failure.
func conntrackParse(data []byte) ([]conntrackEntry, error) {
	scanner := bufio.NewScanner(bytes.NewReader(data))
	var res []conntrackEntry
	var errs []error
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		entry, err := conntrackParseEntry(line)
		if err != nil {
			glog.Errorf("Error while parsing %q: %v", line, err)
			errs = append(errs, err)
			continue
		}
		if entry != nil {
			res = append(res, *entry)
		}
	}
	// Scan returning false can mean either end of input or a scanner failure;
	// only Err tells them apart.
	if err := scanner.Err(); err != nil {
		err = fmt.Errorf("scanning conntrack data: %w", err)
		glog.Errorf("Error while reading conntrack data: %v", err)
		errs = append(errs, err)
	}

	if len(errs) == 0 || len(errs) < len(res) {
		return res, nil
	}
	return nil, fmt.Errorf("encountered too many errors during conntrack parse, check logs; first error: %w", errs[0])
}
// conntrackIndex is an index into a list of conntrackEntries. It allows lookup
// by request/response k/v pairs.
type conntrackIndex struct {
// entries is the list of entries being indexed; the maps below refer to
// its elements by position.
entries []conntrackEntry
// byRequest is a map from key to value to list of indices into entries.
byRequest map[string]map[string][]int
// byResponse is a map from key to value to list of indices into entries.
byResponse map[string]map[string][]int
}
// buildIndex creates a conntrackIndex over the given entries, recording for
// every request and response k/v pair the positions of the entries that carry
// it. The returned index keeps the entries slice it was given.
func buildIndex(entries []conntrackEntry) *conntrackIndex {
	ix := &conntrackIndex{
		entries:    entries,
		byRequest:  map[string]map[string][]int{},
		byResponse: map[string]map[string][]int{},
	}

	// record notes in the two-level map 'into' that the entry at position pos
	// carries every k/v pair found in kvs.
	record := func(into map[string]map[string][]int, kvs map[string]string, pos int) {
		for key, val := range kvs {
			byVal := into[key]
			if byVal == nil {
				byVal = make(map[string][]int)
				into[key] = byVal
			}
			byVal[val] = append(byVal[val], pos)
		}
	}

	for pos := range entries {
		record(ix.byRequest, entries[pos].request, pos)
		record(ix.byResponse, entries[pos].response, pos)
	}
	return ix
}
// getByRequest looks up the entries whose request fields contain the k/v pair
// given. The result holds pointers into the index's own entries, and is nil
// when nothing is indexed under that pair.
func (c *conntrackIndex) getByRequest(k, v string) []*conntrackEntry {
	// Indexing a missing (and therefore nil) inner map is safe and simply
	// reports the value as absent.
	positions, found := c.byRequest[k][v]
	if !found {
		return nil
	}
	matches := make([]*conntrackEntry, 0, len(positions))
	for _, pos := range positions {
		matches = append(matches, &c.entries[pos])
	}
	return matches
}
// getByResponse looks up the entries whose response fields contain the k/v
// pair given. The result holds pointers into the index's own entries, and is
// nil when nothing is indexed under that pair.
func (c *conntrackIndex) getByResponse(k, v string) []*conntrackEntry {
	// Indexing a missing (and therefore nil) inner map is safe and simply
	// reports the value as absent.
	positions, found := c.byResponse[k][v]
	if !found {
		return nil
	}
	matches := make([]*conntrackEntry, 0, len(positions))
	for _, pos := range positions {
		matches = append(matches, &c.entries[pos])
	}
	return matches
}
// find looks up the TCP connection that the 'outside' of the NAT sees as the
// given 4-tuple: an entry whose response fields have the tuple's remote
// address and port as src/sport and its local address and port as dst/dport.
// It returns the first such entry, or nil if there is none or if either
// address of the tuple is not IPv4.
func (c *conntrackIndex) find(t *Tuple4) *conntrackEntry {
	// TODO(q3k): support IPv6
	if t.RemoteIP.To4() == nil || t.LocalIP.To4() == nil {
		return nil
	}

	// conntrack fields are kept as strings, so render the tuple once up
	// front and compare textually.
	wantSport := fmt.Sprintf("%d", t.RemotePort)
	wantDst := t.LocalIP.String()
	wantDport := fmt.Sprintf("%d", t.LocalPort)

	for _, candidate := range c.getByResponse("src", t.RemoteIP.String()) {
		resp := candidate.response
		if candidate.transmissionProtocol == "tcp" &&
			resp["sport"] == wantSport &&
			resp["dst"] == wantDst &&
			resp["dport"] == wantDport {
			return candidate
		}
	}
	return nil
}
// runTranslationWorker runs the conntrack 'translation worker'. It serves
// requests arriving on translationC from an index built out of the conntrack
// file, and returns once ctx is canceled.
//
// The index is built once at startup. A request that misses it triggers one
// reload of the conntrack file followed by one more lookup; if that misses
// too, the request is answered with nil.
func (r *Resolver) runTranslationWorker(ctx context.Context) {
	var ix *conntrackIndex

	// reload replaces ix with an index freshly built from the conntrack
	// file. If the file cannot be read or parsed, the failure is logged and
	// the new index is built from no entries.
	reload := func() {
		var entries []conntrackEntry
		if data, err := ioutil.ReadFile(r.conntrackPath); err != nil {
			glog.Errorf("Failed to read conntrack file: %v", err)
		} else if entries, err = conntrackParse(data); err != nil {
			glog.Errorf("failed to parse conntrack entries: %v", err)
		}
		ix = buildIndex(entries)
	}

	reload()
	for {
		select {
		case <-ctx.Done():
			return
		case req := <-r.translationC:
			entry := ix.find(req.t)
			if entry == nil {
				// The connection may be newer than our snapshot: refresh
				// and look once more.
				reload()
				entry = ix.find(req.t)
			}
			// reply answers with nil when handed a nil entry.
			req.reply(entry)
		}
	}
}