20200404のGoに関する記事は8件です。

もう二度と迷わない!絶対にバグらない二分探索の実装法

この記事について

二分探索アルゴリズムの概念自体はかなり基本的なものですが、いざ実装するとなると「ここのif文やwhile文の条件には=を入れていいんだっけ?」「leftやrightを動かすときはmiddle+1だっけ?middleそのままだっけ?」などたくさんの迷いポイントがあります。
この記事では、考えられる二分探索アルゴリズム全ての実装についてユニットテストを実施して結果を載せ、そこから見えるアルゴリズムの詳しい挙動について考察します。
この記事を全て読んだときには、正しい実装はどうすればいいのかきっちり覚えるはずです。

前提条件

  • アルゴリズムの実装はGo(ver1.14)で行います。
  • 考察すっ飛ばして結果だけ知りたい方は本記事の末尾まで一気に飛ばしてください。

読者に要求する前提知識

  • 二分探索アルゴリズムの概要がわかっていること。(本記事では解説しません)

ループ条件一覧

二分探索の実装でポイントになる箇所は大きく4箇所です。

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right { // ループ継続条件
        middle = ((right - left) / 2) + left
        if key <= array[middle] { // 範囲右寄せ,左寄せ条件
            right = middle // 左寄せ手法
        } else {
            left = middle // 右寄せ手法
        }
    }
    return
}

それぞれのポイント箇所で考えられる実装の組みは、以下16種類です。

No. ループ継続条件 範囲左寄せ条件 範囲右寄せ条件 左寄せ手法 右寄せ手法
1 left<right key<=array[middle] array[middle]<key right = middle left = middle
2 left = middle+1
3 right = middle-1 left = middle
4 left = middle+1
5 key<array[middle] array[middle]<=key right = middle left = middle
6 left = middle+1
7 right = middle-1 left = middle
8 left = middle+1
9 left<=right key<=array[middle] array[middle]<key right = middle left = middle
10 left = middle+1
11 right = middle-1 left = middle
12 left = middle+1
13 key<array[middle] array[middle]<=key right = middle left = middle
14 left = middle+1
15 right = middle-1 left = middle
16 left = middle+1

使用するテストケース

array=[1, 2, 3, 4, 4, 4, 5, 7, 8]に対して、期待する解答は以下の通り。

name key bisect_leftの場合 bisect_rightの場合
TestOver_left -1 0 0
TestStop_even 2 1 2
TestStop_odd 3 2 3
TestStop_block 4 3 6
TestOver_right 9 9 9
TestNot_exist 6 7 7

arrayにすでにkeyが存在している場合の挙動は、

  • bisect_leftの場合は挿入箇所は既存のkeyよりも左側
  • bisect_rightの場合は挿入箇所が既存のkeyよりも右側

参考:Python標準ライブラリ:順序維持のbisect

1

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (0,0,1)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (1,1,2)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,6,7)で無限ループ

2

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,0,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,1,2)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (3,3,3)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,8,9)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (7,6,7)でstop

left, rightの値がbisect_leftとなる

3

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle - 1
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,1,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (0,1,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (1,2,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,7,6)でstop

4

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle - 1
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,1,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (0,1,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (3,2,3)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,8,9)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,5,6)でstop

5

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,1,2)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (5,5,6)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,6,7)で無限ループ

6

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (2,1,2)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (3,3,3)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (6,5,6)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,8,9)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (7,6,7)でstop

left,rightの値がbisect_rightになる

7

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle - 1
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,1,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,2,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (4,4,5)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,7,6)でstop

8

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left < right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle - 1
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,1,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (2,2,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (3,2,3)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (6,5,6)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,8,9)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,5,6)でstop

9

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (0,0,1)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (1,1,2)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,6,7)で無限ループ

10

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,1,1)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,2)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (3,3,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,9,9)でランタイムパニック
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (7,7,7)で無限ループ

11

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle - 1
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,-1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (0,0,0)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (1,1,1)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,6,6)で無限ループ

12

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key <= array[middle] {
            right = middle - 1
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,-1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,0,0)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (3,3,2)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,9,9)でランタイムパニック
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (7,6,6)でstop

13

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,1,2)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (5,5,6)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,6,7)で無限ループ

14

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,0)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (2,2,2)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (3,3,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (6,6,6)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,9,9)でランタイムパニック
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (7,7,7)で無限ループ

15

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle - 1
        } else {
            left = middle
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,-1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (1,1,1)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (2,2,3)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (4,4,5)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (8,8,9)で無限ループ
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (6,6,6)で無限ループ

16

func bisect(array []int, key int) {
    var left, right, middle int
    left = 0
    right = len(array)
    for left <= right {
        middle = ((right - left) / 2) + left
        if key < array[middle] {
            right = middle - 1
        } else {
            left = middle + 1
        }
    }
    return
}

array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=-1 →(left, middle,right) = (0,0,-1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=2 →(left, middle,right) = (2,2,1)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=3 →(left, middle,right) = (3,3,2)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=4 →(left, middle,right) = (6,6,5)でstop
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=9 →(left, middle,right) = (9,9,9)でランタイムパニック
array=[1, 2, 3, 4, 4, 4, 5, 7, 8], key=6 →(left, middle,right) = (7,6,6)でstop

結果一覧

No. ループ継続条件 範囲左寄せ条件 範囲右寄せ条件 左寄せ手法 右寄せ手法 ランタイムパニック テストをpass
1 left<right key<=array[middle] array[middle]<key right = middle left = middle
2 left = middle+1
3 right = middle-1 left = middle
4 left = middle+1
5 key<array[middle] array[middle]<=key right = middle left = middle
6 left = middle+1
7 right = middle-1 left = middle
8 left = middle+1
9 left<=right key<=array[middle] array[middle]<key right = middle left = middle
10 left = middle+1
11 right = middle-1 left = middle
12 left = middle+1
13 key<array[middle] array[middle]<=key right = middle left = middle
14 left = middle+1
15 right = middle-1 left = middle
16 left = middle+1

考察

以下、それぞれの実装においての挙動を考察します。

状態遷移表

範囲右寄せ・左寄せ判定前の(left,middle,right)で状態番号を振る。
表中の色文字の意味は以下の通り。

  • 青文字 : 全てのループ継続条件で停止
  • 緑文字 : ループ継続条件left<rightなら停止
  • 紫文字 : 無限ループ
  • 赤文字 : 続行
状態番号 (left, middle,right) right = middle right = middle-1 left = middle left = middle+1 備考
0 (n,n,n) (n,n,n)→0 (n,n,n-1)→停止 (n,n,n)→0 (n+1,n,n)→停止
1 (n,n,n+1) (n,n,n)→0 (n,n,n-1)→停止 (n,n,n+1)→1 (n+1,n,n+1)→0
2 (n,n+1,n+2) (n,n+1,n+1)→1 (n,n+1,n)→0 (n+1,n+1,n+2)→1 (n+2,n+1,n+2)→0
3 (n,n+1,n+3) (n,n+1,n+1)→1 (n,n+1,n)→0 (n+1,n+1,n+3)→2 (n+2,n+1,n+3)→1
4 (n,n+2,n+4) (n,n+2,n+2)→2 (n,n+2,n+1)→1 (n+2,n+2,n+4)→2 (n+3,n+2,n+4)→1
2k (n,n+k,n+2k) (n,n+k,n+k)→k (n,n+k,n+k-1)→k-1 (n+k,n+k,n+2k)→k (n+k+1,n+k,n+2k)→k-1 k>=2
2k+1 (n,n+k,n+2k+1) (n,n+k,n+k)→k (n,n+k,n+k-1)→k-1 (n+k,n+k,n+2k+1)→k+1 (n+k+1,n+k,n+2k+1)→k k>=2

無限ループが起こる条件

無限ループが起こっているパターンとしては2種類。

1. (n,n,n+1)

そのため(n,?,n+1)という状況になったときに、

  • middle = ((right - left) / 2) + leftの部分でmiddle=left=nのまま変化しない
  • left=middleの範囲右寄せが選ばれたときにleft=nで変化しない

のまま膠着状態になる。
→範囲右寄せ手法がleft=middleとなっている1,3,5,7,9,11,13,15で発生しうる。

(n,_,n+1)の状態はどうあがいても起きうるので、「middle = ((right - left) / 2) + leftの部分でmiddle=left=nのまま変化しない」を避けることは不可能。
そのため、left=middle+1は無限ループを起こさないための必要条件となる。

2. (n,n,n)

これで無限ループになるにはループ継続条件がleft<=rightでないといけない。このときに

  • middle = ((right - left) / 2) + leftの部分でmiddle=leftのまま変化しない
  • left=middleorright=middleが選ばれた場合、leftやrightが変化しない

という状態のまま膠着状態になる。
→範囲右寄せ・左寄せで必ずleft,rightの値が変化する12,16以外、つまり9,10,11,13,14,15で起こりうる。

stopする時の(left, middle,right)の組パターン

パターンとしては5種類。

1. (n,n,n)

whileループ開始時に(n,?,n+1)であった場合、次のmiddle定義で必ず(n,n,n+1)になる。
このとき、範囲左寄せが選ばれright=middleが実行された場合、(n,n,n)の状態になる。
このとき、left=rightで停止するループ継続条件left<rightであった場合、停止する。
→おきうるのは1,2,5,6

2.(n,n-1,n)

whileループ開始時に(n-1,?,n)であった場合、次のmiddle定義で必ず(n-1,n-1,n)になる。
このとき、範囲右寄せが選ばれleft=middle+1が実行された場合、(n,n-1,n)の状態になる。
このとき、left=rightで停止するループ継続条件left<rightであった場合、停止する。
→おきうるのは2,4,6,8

whileループ開始時に(n-2,?,n)であった場合、次のmiddle定義で必ず(n-2,n-1,n)になる。
このとき、範囲右寄せが選ばれleft=middle+1が実行された場合、(n,n-1,n)の状態になる。
このとき、left=rightで停止するループ継続条件left<rightであった場合、停止する。
→おきうるのは2,4,6,8

3.(n,n+1,n)

whileループ開始時に(n,?,n+3)であった場合、次のmiddle定義で必ず(n,n+1,n+3)になる。
このとき、範囲左寄せが選ばれright=middle-1が実行された場合、(n,n+1,n)の状態になる。
このとき、left=rightで停止するループ継続条件left<rightであった場合、停止する。
→おきうるのは3,4,7,8

whileループ開始時に(n,?,n+2)であった場合、次のmiddle定義で必ず(n,n+1,n+2)になる。
このとき、範囲左寄せが選ばれright=middle-1が実行された場合、(n,n+1,n)の状態になる。
このとき、left=rightで停止するループ継続条件left<rightであった場合、停止する。
→おきうるのは3,4,7,8

4.(n,n,n-1)

whileループ開始時に(n,?,n+1)であった場合、次のmiddle定義で必ず(n,n,n+1)になる。
このとき、範囲左寄せが選ばれright=middle-1が実行された場合、(n,n,n-1)の状態になる。
どんなループ継続条件であってもこれで停止する。
→おきうるのは3,4,7,8

whileループ開始時に(n,?,n)であった場合、次のmiddle定義で必ず(n,n,n)になる。
このとき、範囲左寄せが選ばれright=middle-1が実行された場合、(n,n,n-1)の状態になる。
どんなループ継続条件であってもこれで停止する。
→おきうるのは11,12,15,16

5.(n+1,n,n)

whileループ開始時に(n,?,n)であった場合、次のmiddle定義で必ず(n,n,n)になる。
このとき、範囲右寄せが選ばれleft=middle+1が実行された場合、(n+1,n,n)の状態になる。
どんなループ継続条件であってもこれで停止する。
→おきうるのは10,12,14,16

(n,n,n+1)がstopパターンにならずに無限ループパターンになるのに、(n,n,n)が両方あることの違い

whileループ終わりで(n,n,n+1)となった場合は、while条件がleft<rightでもleft<=rightでも次のループが問題なく回る。
対して(n,n,n)の場合はwhile条件がleft<rightのとき停止する。

ランタイムパニックが起こる条件

ランタイムパニックになるのがleft=right=len(a)のときのみ
このとき、次のループのmiddle = ((right - left) / 2) + leftの部分でleft=middle=right=len(a)という状況になる。
arrayの配列のインデックス範囲は0~len(a)-1なのでこれでエラーになる。

これが起こりうるのは10,12,14,16
<理由>

  • ループ継続条件がleft<rightのときは、left=right=len(a)の時点で停止するのでmiddle=len(a)の代入が行われない
  • 状態遷移図で「→0」となっているもののうち、範囲右寄せ・左寄せでrightが減少していないのがleft=middle+1のとき

ループ継続条件に=が入るとテストが一個も通らないわけ

少なくともテストを通すためには、無限ループとランタイムパニックを避ける必要がある。

  • (n,n,n+1)の無限ループの条件に当たる→9,11,13,15
  • (n,n,n)の無限ループの条件に当たる→9,10,11,13,14,15
  • ランタイムパニックが起こる条件に当たる→10,12,14,16

left<=rightである9~16は結局全部ダメなので、正しく実装するための必要条件はループ継続条件left<right

bisect_rightでもleftでもない4と8って?

whileループの間で常に成り立つ不変条件は以下の通り。

  • 2のとき : i<left → array[i]<keyかつi>=right → key<=array[i]
  • 4のとき : i<left → array[i]<keyかつi>right → key<=array[i]
  • 6のとき : i<left → array[i]<=keyかつi>=right → key<array[i]
  • 8のとき : i<left → array[i]<=keyかつi>right → key<array[I]

これにループが停止したときのleft,rightを代入すると以下のようになる。

(left,right) (n,n) (n,n-1)
2のときに代入 i<n → array[i]<keyかつi>=n → key<=array[i] ×
4のときに代入 keyarray[n]の大小関係が決定されない i<n → array[i]<keyかつi>n-1 → key<=array[i]
6のときに代入 i<n → array[i]<=keyかつi>=n → key<array[i] ×
8のときに代入 keyarray[n]の大小関係が決定されない i<n → array[i]<=keyかつi>n-1 → key<array[i]

4,8では大小関係がわからないときがあるので、このときのleftとrightに特別な意味はない。

bisectの返り値にleftを返すべきかmiddleを返すべきかrightを返すべきか

whileループの不変条件で出てくる変数はleftとrightであるので、middleはbisect関数の返り値に選ぶべきではない。
正しいbisect関数になりうる2,6では、停止時に常にleft=rightが成り立つため、leftとrightはどちらを返り値にしてもよい。

bisect_leftとbisect_rightはどう決まるのか

2と6の実装と不変条件をみると、

  • 2のとき : array[middle]<keyleft = middle+1からi<left → array[i]<keyが導かれる
  • 6のとき : key<array[middle]right = middleからi>=right → key<array[i]が導かれる

なので、array[middle]<keyのとき(=2)がbisect_leftとなり、key<array[middle]のとき(=6)の時がbisect_rightとなる。

まとめ

二分探索の正しい実装条件は以下の通りです。

bisect_left,right共通事項

bisect_leftとrightの区別は

  • leftだと範囲左寄せがkey<=array[middle]、範囲右寄せがarray[middle]<key
  • rightだと範囲左寄せがkey<array[middle]、範囲右寄せがarray[middle]<=key

(Link→bisect_leftとbisect_rightはどう決まるのか)

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

AtCoder Beginner Contest 161 参戦記

AtCoder Beginner Contest 161 参戦記

ABC161A - ABC Swap

3分で突破. 書くだけ. オンラインのコードテストが詰まってて時間がかかってしまった.

X, Y, Z = map(int, input().split())

X, Y = Y, X
X, Z = Z, X
print(X, Y, Z)

ABC161B - Popular Vote

4分で突破. 閾値が総投票数の 4 * M 分の一なのでまずそれを求め、その閾値を超える票数の商品がM個以上あるかを調べる. オンラインのコードテストが詰まってて時間がかかってしまった.

N, M = map(int, input().split())
A = list(map(int, input().split()))

threshold = sum(A) / (4 * M)
if len([a for a in A if a >= threshold]) >= M:
    print('Yes')
else:
    print('No')

ABC161C - Replacing Integer

6分で突破. N が K を超えていれば、まず剰余を取る. N が K 未満になった後は、適当に回せば収束するだろうと、適当に1000回して放り込んだら AC が出たので結果オーライ. 後で真面目に考え直す.

N, K = map(int, input().split())

result = N
if N > K:
    result = min(result, N % K)
for i in range(1000):
    result = min(result, abs(result - K))
print(result)

追記: N % K を x とする. x < K なので、x と K の差の絶対値は K - x となる. ところで K と K - x の差の絶対値は x である. よって、x と K - x の小さい方が答えとなる.

N, K = map(int, input().split())

x = N % K
print(min(x, K - x))

ABC161D - Lunlun Number

49分で突破. 当然数字を1づつ増やしながらのループでは TLE 必須なのでスキップを考える. ルンルン数は 12 の次が 21 となる. 13は3に問題があるが、3の桁ではなく、1の桁が2になるまでルンルン数は発生しない. 問題が発生した場合は上の桁を一つ進めて、それより下の桁を0にフラッシュしてみる. それだけだと、13→20→30となってしまうので、問題のある桁の値がその一つ前の桁より小さい場合には、問題のある桁の値を一つ進めることにした. これで 13→21 と進むようになり AC. コードは is_lunlun が真偽値ではなく数値を返すやっつけなので後で直す.

K = int(input())


def is_lunlun(i):
    result = -1
    n = [ord(c) - 48 for c in str(i)]
    for j in range(len(n) - 1):
        if abs(n[j] - n[j + 1]) <= 1:
            continue
        if n[j] < n[j + 1]:
            for k in range(j + 1, len(n)):
                n[k] = 0
            result = int(''.join(str(k) for k in n))
            result += 10 ** (len(n) - (j + 1))
        else:
            result = int(''.join(str(k) for k in n))
            result += 10 ** (len(n) - (j + 2))
        break
    return result


i = 1
while True:
    # print(i)
    t = is_lunlun(i)
    if t == -1:
        K -= 1
        if K == 0:
            print(i)
            exit()
        i += 1
    else:
        i = t

ABC161E - Yutori

順位表を見るに F のほうが明らかに簡単そうなのでパスした.

ABC161F - Division or Substraction

突破できず.

追記: 30分くらい追加して解けた. トータル1時間くらい? N≤1012 なので素直にループを回すと TLE 必至. K * K > N の領域では、K = N, N - 1 を除けば N = a * K + 1 (a≥2) のパターンしか無い. 候補は、2以上 sqrt(N) 以下の K と、N - 1 が K で割り切れる時の (N - 1) / K. 候補をすべてチェックするのはたかだか 2 * 106 なので間に合う.

package main

import (
    "bufio"
    "fmt"
    "math"
    "os"
    "strconv"
)

func main() {
    N := readInt()

    if N == 2 {
        // K = 2
        fmt.Println(1)
        return
    }

    result := 2 // K = N - 1, N
    for K := 2; K <= int(math.Sqrt(float64(N))); K++ {
        t := N
        for t >= K && t%K == 0 {
            t /= K
        }
        if t%K == 1 {
            result++
        }

        if (N-1)%K == 0 && (N-1)/K > K {
            result++
        }
    }
    fmt.Println(result)
}

const (
    ioBufferSize = 1 * 1024 * 1024 // 1 MB
)

var stdinScanner = func() *bufio.Scanner {
    result := bufio.NewScanner(os.Stdin)
    result.Buffer(make([]byte, ioBufferSize), ioBufferSize)
    result.Split(bufio.ScanWords)
    return result
}()

func readString() string {
    stdinScanner.Scan()
    return stdinScanner.Text()
}

func readInt() int {
    result, err := strconv.Atoi(readString())
    if err != nil {
        panic(err)
    }
    return result
}

func readInts(n int) []int {
    result := make([]int, n)
    for i := 0; i < n; i++ {
        result[i] = readInt()
    }
    return result
}
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

GormでIN、OR、UNION ALLのクエリ書いてみた

Gormで色々とクエリ書いて見て少し理解したのでメモ

UNION ALLに関しては公式にそれらしいものがなかったので独自で書いてみた

user.go
package domain

import "time"

type User struct {
    ID         string
    Name       string
    CreatedAt  time.Time
    UpdatedAt  time.Time
}
repository.go
func (repo *Repository) FindAllByNames(names []string) ([]*domain.User, error) {
    var users []*domain.User

    // IN
    // SELECT * FROM `users`  WHERE (name IN ('taro','hanako'))
    repo.db.Debug().Find(&users, "name IN (?)", names)

    // OR
    // SELECT * FROM `users`  WHERE (name = 'taro') OR (name = 'hanako')
    query := repo.db.Debug()
    for _, v := range names {
        query = query.Or("name = ?", v)
    }
    query.Find(&users)

    // UNION ALL
    // SELECT * FROM users WHERE name = 'taro' UNION ALL SELECT * FROM users WHERE name = 'hanako'
    query = repo.db.Debug()
    var sql string = ""
    for i, v := range names {
        q := fmt.Sprintf(`SELECT * FROM users WHERE name = '%s'`, v)
        if i == 0 {
            sql += q
        } else {
            sql += " UNION ALL " + q
        }
    }
    query.Raw(sql).Find(&users)

    return users, nil
}

うん、やっぱりUNION ALL がちょっとダサい、、、

もっと良い書き方あればコメントください><

それにしてもGormの公式ドキュメントめっちゃ見やすくて、実装捗るな〜
https://gorm.io/

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Go言語の環境構築

Go言語の環境構築

備忘録がてら

実施環境

  • macOS Catalina 10.15.3

Home brewのインストール

割愛

Goのインストール

$ brew install go
$ go version
go version go1.13.8 darwin/amd64

GoPathの設定

$ mkdir $HOME/go
$ vi ~/.bash_profile

以下を追加
------------------------------------
export GOPATH=$HOME/go
export PATH=$PATH:$HOME/go/bin
------------------------------------

$ source ~/.bash_profile

vscodeの設定

  • Goの拡張機能を追加
  • 便利ツールたちの追加
    • shift + command + P
    • Go install/Update Tools をクリック
    • 全部にチェックを入れてOK

Hello World

  • $HOME/goの下に/srcを作成
  • $HOME/go/srcの下にhogehoge(プロジェクト名)を作成

  • $HOME/go/src/hogehogeに以下ファイルを追加

main.go
package main

import "fmt"

func main() {
    fmt.Println("Hello, 世界")
}
  • $HOME/go/src/hogehogeで以下を実行
$ go run main.go 
Hello, 世界

Goのディレクトリ構成

ディレクトリ構成一例。
/bin,/pkg,/srcは基本

GOPATH
├── bin
├── pkg
└── src
    ├── hogehoge
    └── github.com

goコマンドたち

# インストールコマンド 成果物は/bin直下
$ go install {プロジェクトフォルダ}

# ビルドコマンド 成果物は同じパス
$ go build {対象ファイル}

# 環境設定一覧(GOPATHなど)
$ go env 

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Go database/sql(コネクションプール/タイムアウト)

はじめに

本記事は、前回の続きで、database/sqlに関するメモです。

  • コネクションプール
  • クエリーのタイムアウト

DBMSにはpostgreSQLを使っています。

コネクションプール sql.DB

コネクションプールってどうやるんだっけと調べてみました。
ドキュメントを確認していると、sql.Openに該当する記述を見つけました。

The returned DB is safe for concurrent use by multiple goroutines and maintains its own pool of idle connections. Thus, the Open function should be called just once. It is rarely necessary to close a DB.

「sql.Openで得たDBはコネクションプールなので、頻繁にOpen/Closeする必要は無い。」ということのようです。
なるほどと思い、DBでコネクションプールの設定関連調べたところ、下表のメソッドが用意されていました。

メソッド名 概要
func (db *DB) SetMaxOpenConns(n int) 接続の最大数を設定。 nに0以下の値を設定で、接続数は無制限。
func (db *DB) SetMaxIdleConns(n int) コネクションプールの最大接続数を設定。
func (db *DB) SetConnMaxLifetime(d time.Duration) 接続の再利用が可能な時間を設定。dに0以下の値を設定で、ずっと再利用可能。

それぞれどのように設定するのが、良いのかと色々調べてて、以下のサイトを見つけました。

DSAS開発者の部屋 Re: Configuring sql.DB for Better Performance

こちらのサイトにとても分かりやすく解説してあります。結論の部分だけ引用させて頂きますが、「SetConnMaxLifetime を使う他の理由」の部分も必読なので、興味のある方は是非ご一読ください。

・SetMaxOpenConns() は必ず設定する。負荷が高くなってDBの応答が遅くなったとき、新規接続してさらにクエリを投げないようにするため。できれば負荷試験をして最大のスループットを発揮する最低限のコネクション数を設定するのが良いが、負荷試験をできない場合も max_connection やコア数からある程度妥当な値を判断するべき。
・SetMaxIdleConns() は SetMaxOpenConns() 以上に設定する。アイドルな接続の解放は SetConnMaxLifetime に任せる。
・SetConnMaxLifetime() は最大接続数 × 1秒 程度に設定する。多くの環境で1秒に1回接続する程度の負荷は問題にならない。1時間以上に設定したい場合はインフラ/ネットワークエンジニアによく相談すること。

というわけで、以下のように実装してみた。

func setupDB(dbDriver string, dsn string) (*sql.DB, error) {
    db, err := sql.Open(dbDriver, dsn)
    if err != nil {
        return nil, err
    }
    db.SetMaxIdleConns(10)
    db.SetMaxOpenConns(10)
    db.SetConnMaxLifetime(10 * time.Second)

    return db, err
}

SetMaxIdleConnsとSetMaxOpenConnsに設定する値の関係ですが、SetMaxOpenConnsの実装が以下のようになっていることを考えると、コネクションプールを使う場合は、同じ値を設定しておくと無難なのかなと思います。

func (db *DB) SetMaxOpenConns(n int) {
    db.mu.Lock()
    db.maxOpen = n
    if n < 0 {
        db.maxOpen = 0
    }
    syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
    db.mu.Unlock()
    if syncMaxIdle {
        db.SetMaxIdleConns(n)
    }
}

タイムアウト context.Context

contextを使うと、WithTimeoutで指定した時間を経過した場合に、クエリーの実行を中断させることが可能です。

// 10秒でタイムアウト
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)

以下は実際の呼び出し例です。
タイムアウトの設定は、タイムアウトが発生しやすいように、1ナノ秒を設定してます。

func selectUserByName(tx *sql.Tx, ctx context.Context, name string) (*User, error) {
    u := &User{}
    if err := tx.QueryRowContext(ctx, "select * from t_user where name=$1", name).Scan(&u.ID, &u.Name, &u.profile, &u.Created, &u.Updated); err != nil {
        log.Fatal(err)
    }
    return u, nil
}

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
fmt.Println(selectUserByName(tx, ctx, "taka499999"))

QueryRowContextがタイムアウトによってエラーとなると、postgreSQLの場合であれば以下のようなメッセージが表示されます。

2020/04/02 10:17:25 pq: ユーザからの要求により文をキャンセルしています

ちなみに、WithTimeoutの設定をQueryContextの結果が正しく返却されるに十分な100ミリ秒とし、その直後に1秒のスリープをいれてQueryContextを呼び出した場合、「context deadline exceeded」というメッセージと共にエラーになりました。

    ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
    time.Sleep(1 * time.Second)

期待する動作としては、QueryContextを呼び出してからWithTimeoutで設定した時間経過したらタイムアウトエラーになってほしいのですが、実際はそうではなさそうです。

タイムアウトはどこを起点としているかってことになるので、WithTimeoutを調べてみました。
結論として、起点となるのは、WithTimeoutを呼び出した日時とパラメータのtimeoutを加算した日時です。

func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
    return WithDeadline(parent, time.Now().Add(timeout))
}

ひとまずQueryContextを呼び出す直前にWithTimeoutを呼び出すようにしたとしても、クエリーごとにその実装を行うのは何か違うような気もしますし、QueryContextのエラーの原因がタイムアウト以外にあった場合に、エラーハンドリングが面倒なので、Goらしい実装ってどんなんだろっていうのが気になってます。

参考文献

この記事は以下の情報を参考にして執筆しました。

-https://golang.org/pkg/database/sql/
-DSAS開発者の部屋 Re: Configuring sql.DB for Better Performance

  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Go+Google Cloud Logging

Go の勉強がてら Google Cloud Logging を使ってみた。
以下の記事で触れられているログをリクエストログに紐づいて表示させられるようにもしてある。

gclog.go
package gclog

import (
    "context"
    "fmt"
    "os"
    "runtime"
    "strconv"
    "strings"
    "time"

    "cloud.google.com/go/logging"
    "github.com/gin-gonic/gin"
    mrpb "google.golang.org/genproto/googleapis/api/monitoredres"
    logpb "google.golang.org/genproto/googleapis/logging/v2"
)

// 以下の三つは Context に入れるべきだが簡略化のためここに記載
var WhenBegin time.Time
var TraceID string
var SpanID string

func NewClient(ctx context.Context) (*logging.Client, *logging.Logger, error) {
    projectID := "your project here." // FIXME
    client, err := logging.NewClient(ctx, projectID)

    if err != nil {
        return client, nil, err
    }

    return client, client.Logger(projectID), err
}

func Log(ctx context.Context, logger *logging.Logger, severity logging.Severity, msg string) {
    entry := logging.Entry{
        Payload:  msg,
        Severity: severity,
        Resource: &mrpb.MonitoredResource{
            Type: "gae_app",
        },
    }

    if TraceID != "" {
        entry.Trace = TraceID
        entry.SpanID = strconv.FormatInt(time.Now().UnixNano(), 10)[:16]
    }

    pc, file, line, ok := runtime.Caller(1)

    if ok {
        entry.SourceLocation = &logpb.LogEntrySourceLocation{
            File:     file,
            Line:     int64(line),
            Function: runtime.FuncForPC(pc).Name(),
        }
    }

    Logger.Log(entry)
}

func LogRequest(ctx *gin.Context, logger *logging.Logger, severity logging.Severity) {
    entry := logging.Entry{
        Severity: severity,
        Resource: &mrpb.MonitoredResource{
            Type: "gae_app",
        },
        HTTPRequest: &logging.HTTPRequest{
            Request: ctx.Request,
            Latency: time.Now().Sub(WhenBegin),
        },
    }

    if TraceID != "" {
        entry.Trace = TraceID
        entry.SpanID = SpanID
    }

    pc, file, line, ok := runtime.Caller(1)

    if ok {
        entry.SourceLocation = &logpb.LogEntrySourceLocation{
            File:     file,
            Line:     int64(line),
            Function: runtime.FuncForPC(pc).Name(),
        }
    }

    logger.Log(entry)
}

func Begin(ctx *gin.Context) {
    WhenBegin = time.Now()
    split := strings.Split(ctx.GetHeader("X-Cloud-Trace-Context"), "/")
    if len(split) >= 2 {
        TraceID = split[0]
        split = strings.Split(split[1], ";")
        SpanID = split[0]
    }
}

func End(ctx *gin.Context) {
    LogRequest(ctx, logging.Info)
}
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Go+Google Cloud Firestore

Go の勉強がてら Google Cloud Datastore へ読み書きする簡単なプログラムを作成したため共有する。
Go は Java と比べてリフレクション周りが癖があって扱いにくかった。
あと、オーバーロードがないのが辛い。

fb.go
package fb

import (
    "context"
    "fmt"
    "reflect"

    "cloud.google.com/go/firestore"
    firebase "firebase.google.com/go"
    "google.golang.org/api/iterator"
)

type Document interface {
    ID() string
    SetID(id string)
    Collection() string
}

func NewApp(ctx context.Context) (*firebase.App, error) {
    return NewAppWithProjectID(ctx, "your project id here.") // FIXME
}

func NewAppWithProjectID(ctx context.Context, projectID string) (*firebase.App, error) {
    conf := &firebase.Config{ProjectID: projectID}
    app, err := firebase.NewApp(ctx, conf)
    if err == nil {
        return app, err
    }

    return nil, err
}

func NewClient(ctx context.Context) (*firestore.Client, error) {
    return NewClientWithProjectID(ctx, "your project id here.") // FIXME
}

func NewClientWithProjectID(ctx context.Context, projectID string) (*firestore.Client, error) {
    app, err := NewAppWithProjectID(ctx, projectID)
    if err != nil {
        return nil, err
    }

    client, err := app.Firestore(ctx)
    if err == nil {
        return client, err
    }

    return nil, err
}

func GetDoc(ctx context.Context, client *firestore.Client, doc Document) error {
    var dsnap *firestore.DocumentSnapshot
    var err error

    if client == nil {
        dsnap, err = Client.Collection(doc.Collection()).Doc(doc.ID()).Get(ctx)
    } else {
        dsnap, err = client.Collection(doc.Collection()).Doc(doc.ID()).Get(ctx)
    }

    if err != nil {
        return err
    }

    err = dsnap.DataTo(doc)
    if err != nil {
        return err
    }

    return nil
}

func GetDocs(ctx context.Context, q firestore.Query, dst interface{}) error {
    var elemIsPtr bool
    slice := reflect.ValueOf(dst).Elem() // dv はスライスになる
    elemType := slice.Type().Elem()      // elemType はスライス slice の要素の型になる

    switch elemType.Kind() {
    case reflect.Struct:
        elemIsPtr = false
    case reflect.Ptr:
        elemIsPtr = true
        elemType = elemType.Elem()
    default:
        return fmt.Errorf("unsupported slice element type: %v", elemType)
    }

    iter := q.Documents(ctx)

    for {
        doc, err := iter.Next()
        if err == iterator.Done {
            break
        }
        if err != nil {
            return err
        }

        entity := reflect.New(elemType)
        entity.Interface().(Document).SetID(doc.Ref.ID)
        err = doc.DataTo(entity.Interface())
        if err != nil {
            return err
        }

        if elemIsPtr {
            slice.Set(reflect.Append(slice, entity))
        } else {
            slice.Set(reflect.Append(slice, entity.Elem()))
        }
    }

    return nil
}

func PutDoc(ctx context.Context, client *firestore.Client, doc Document) error {
    var ref *firestore.DocumentRef
    var err error

    if doc.ID() == "" {
        if client == nil {
            ref, _, err = Client.Collection(doc.Collection()).Add(ctx, doc)
        } else {
            ref, _, err = client.Collection(doc.Collection()).Add(ctx, doc)
        }
    } else {
        if client == nil {
            _, err = Client.Collection(doc.Collection()).Doc(doc.ID()).Set(ctx, doc)
        } else {
            _, err = client.Collection(doc.Collection()).Doc(doc.ID()).Set(ctx, doc)
        }
    }

    if err != nil {
        return err
    }

    if ref != nil {
        doc.SetID(ref.ID)
    }

    return nil
}

func PutDocs(ctx context.Context, client *firestore.Client, entities interface{}) error {
    var elemIsPtr bool
    slice := reflect.ValueOf(entities)
    elemType := slice.Type().Elem() // elemType はスライス slice の要素の型になる

    switch elemType.Kind() {
    case reflect.Struct:
        elemIsPtr = false
    case reflect.Ptr:
        elemIsPtr = true
        elemType = elemType.Elem()
    default:
        return fmt.Errorf("unsupported slice element type: %v", elemType)
    }

    var ref *firestore.CollectionRef
    var batch *firestore.WriteBatch

    if client == nil {
        ref = Client.Collection(reflect.New(elemType).Interface().(Document).Collection())
        batch = Client.Batch()
    } else {
        ref = client.Collection(reflect.New(elemType).Interface().(Document).Collection())
        batch = client.Batch()
    }

    for i := 0; i < slice.Len(); i++ {
        var doc Document
        if elemIsPtr {
            doc = slice.Index(i).Interface().(Document)
        } else {
            doc = slice.Index(i).Addr().Interface().(Document)
        }

        if doc.ID() == "" {
            newDoc := ref.NewDoc()
            _ = batch.Set(newDoc, doc)
            doc.SetID(newDoc.ID)
        } else {
            _ = batch.Set(ref.Doc(doc.ID()), doc)
        }
    }

    _, err := batch.Commit(ctx)
    if err != nil {
        return err
    }

    return nil
}

func DeleteDoc(ctx context.Context, client *firestore.Client, doc Document) error {
    var err error
    if client == nil {
        _, err = Client.Collection(doc.Collection()).Doc(doc.ID()).Delete(ctx)
    } else {
        _, err = client.Collection(doc.Collection()).Doc(doc.ID()).Delete(ctx)
    }

    if err != nil {
        return err
    }

    return nil
}

func DeleteDocs(ctx context.Context, client *firestore.Client, entities interface{}) error {
    var elemIsPtr bool
    slice := reflect.ValueOf(entities)
    elemType := slice.Type().Elem() // elemType はスライス slice の要素の型になる

    switch elemType.Kind() {
    case reflect.Struct:
        elemIsPtr = false
    case reflect.Ptr:
        elemIsPtr = true
        elemType = elemType.Elem()
    default:
        return fmt.Errorf("unsupported slice element type: %v", elemType)
    }

    var ref *firestore.CollectionRef
    var batch *firestore.WriteBatch

    if client == nil {
        ref = Client.Collection(reflect.New(elemType).Interface().(Document).Collection())
        batch = Client.Batch()
    } else {
        ref = client.Collection(reflect.New(elemType).Interface().(Document).Collection())
        batch = client.Batch()
    }

    for i := 0; i < slice.Len(); i++ {
        var doc Document
        if elemIsPtr {
            doc = slice.Index(i).Interface().(Document)
        } else {
            doc = slice.Index(i).Addr().Interface().(Document)
        }
        _ = batch.Delete(ref.Doc(doc.ID()))
    }

    _, err := batch.Commit(ctx)
    if err != nil {
        return err
    }

    return nil
}
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む

Go+Google Cloud Datastore

Go の勉強がてら Google Cloud Datastore へ読み書きする簡単なプログラムを作成したため共有する。
Go は Java と比べてリフレクション周りが癖があって扱いにくかった。
あと、オーバーロードがないのが辛い。

ds.go
package ds

import (
    "context"
    "fmt"
    "reflect"

    "cloud.google.com/go/datastore"
)

type Entity interface {
    EntityKey() *datastore.Key
    SetEntityKey(key *datastore.Key)
    EntityKind() string
}

func NewClient(ctx context.Context) (*datastore.Client, error) {
    return NewClientWithProjectID(ctx, "your project id here.") // FIXME
}

func NewClientWithProjectID(ctx context.Context, projectID string) (*datastore.Client, error) {
    client, err := datastore.NewClient(ctx, projectID)

    return client, err
}

func GetEntity(ctx context.Context, entity Entity) error {
    client, err := NewClient(ctx)
    if err != nil {
        return err
    }

    defer client.Close()

    return GetEntityWithClient(ctx, client, entity)
}

func GetEntityWithClient(ctx context.Context, client *datastore.Client, entity Entity) error {
    entity.EntityKey().Kind = entity.EntityKind()

    if err := client.Get(ctx, entity.EntityKey(), entity); err != nil {
        return err
    }

    return nil
}

func GetEntities(ctx context.Context, q *datastore.Query, dst interface{}) error {
    client, err := NewClient(ctx)
    if err != nil {
        return err
    }

    defer client.Close()

    return GetEntitiesWithClient(ctx, client, q, dst)
}

func GetEntitiesWithClient(ctx context.Context, client *datastore.Client, q *datastore.Query, dst interface{}) error {
    keys, err := client.GetAll(ctx, q, dst)
    if err != nil {
        return err
    }

    slice := reflect.ValueOf(dst).Elem()

    if len(keys) != slice.Len() {
        return fmt.Errorf("failed to get entities: len(keys)=%d, slice.Len()=%d", len(keys), slice.Len())
    }

    for i, key := range keys {
        entity := slice.Index(i).Addr().Interface().(Entity)
        entity.SetEntityKey(key)
    }

    return nil
}

func PutEntity(ctx context.Context, entity Entity) error {
    client, err := NewClient(ctx)
    if err != nil {
        return err
    }

    defer client.Close()

    return PutEntityWithClient(ctx, client, entity)
}

func PutEntityWithClient(ctx context.Context, client *datastore.Client, entity Entity) error {
    entity.EntityKey().Kind = entity.EntityKind()

    key, err := client.Put(ctx, entity.EntityKey(), entity)
    if err != nil {
        return err
    }

    entity.SetEntityKey(key)

    return nil
}

func PutEntities(ctx context.Context, entities interface{}) error {
    client, err := NewClient(ctx)
    if err != nil {
        return err
    }

    defer client.Close()

    return PutEntitiesWithClient(ctx, client, entities)
}

func PutEntitiesWithClient(ctx context.Context, client *datastore.Client, entities interface{}) error {
    slice := reflect.ValueOf(entities)
    keys := make([]*datastore.Key, 0, slice.Len())
    for i := 0; i < slice.Len(); i++ {
        entity := slice.Index(i).Addr().Interface().(Entity)

        entity.EntityKey().Kind = entity.EntityKind()
        keys = append(keys, entity.EntityKey())
    }

    keys, err := client.PutMulti(ctx, keys, entities)
    if err != nil {
        return err
    }
    if len(keys) != slice.Len() {
        return fmt.Errorf("failed to put entities: len(keys)=%d, slice.Len()=%d", len(keys), slice.Len())
    }

    for i, key := range keys {
        entity := slice.Index(i).Addr().Interface().(Entity)

        entity.SetEntityKey(key)
    }

    return nil
}

func DeleteEntity(ctx context.Context, key *datastore.Key) error {
    client, err := NewClient(ctx)
    if err != nil {
        return err
    }

    defer client.Close()

    return DeleteEntityWithClient(ctx, client, key)
}

func DeleteEntityWithClient(ctx context.Context, client *datastore.Client, key *datastore.Key) error {
    if err := client.Delete(ctx, key); err != nil {
        return err
    }

    return nil
}

func DeleteEntities(ctx context.Context, keys []*datastore.Key) error {
    client, err := NewClient(ctx)
    if err != nil {
        return err
    }

    defer client.Close()

    return DeleteEntitiesWithClient(ctx, client, keys)
}

func DeleteEntitiesWithClient(ctx context.Context, client *datastore.Client, keys []*datastore.Key) error {
    if err := client.DeleteMulti(ctx, keys); err != nil {
        return err
    }

    return nil
}
  • このエントリーをはてなブックマークに追加
  • Qiitaで続きを読む